mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
OPT - Fix Softmax NaN in half precision mode (#17437)
This commit is contained in:
parent
9fe2403bc5
commit
d444edb3f6
@ -109,7 +109,6 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
|
||||
return super().forward(positions + self.offset)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->OPT
|
||||
class OPTAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@ -212,9 +211,15 @@ class OPTAttention(nn.Module):
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
dtype_attn_weights = attn_weights.dtype
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
|
||||
if dtype_attn_weights == torch.float16:
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
|
||||
else:
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
@ -382,7 +387,7 @@ class OPTPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["OPTDecoderLayer"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"decoder.version"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
|
@ -22,7 +22,7 @@ import unittest
|
||||
import timeout_decorator # noqa
|
||||
|
||||
from transformers import OPTConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@ -428,3 +428,25 @@ class OPTGenerationTest(unittest.TestCase):
|
||||
predicted_outputs += generated_string
|
||||
|
||||
self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_batched_nan_fp16(self):
|
||||
# a bug manifested starting at models facebook/opt-1.3 and larger when running batched generations,
|
||||
# therefore not using a tiny model, but the smallest model the problem was seen with which is opt-1.3b.
|
||||
# please refer to this github thread: https://github.com/huggingface/transformers/pull/17437 for more details
|
||||
model_name = "facebook/opt-1.3b"
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(model_name, use_fast=False, padding_side="left")
|
||||
|
||||
model = OPTForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()
|
||||
model = model.eval()
|
||||
|
||||
batch = tokenizer(["Who are you?", "Joe Biden is the president of"], padding=True, return_tensors="pt")
|
||||
|
||||
input_ids = batch["input_ids"].cuda()
|
||||
attention_mask = batch["attention_mask"].cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids, attention_mask=attention_mask)
|
||||
self.assertFalse(
|
||||
torch.isnan(outputs.logits[0]).any().item()
|
||||
) # the first logits could contain NaNs if it fails
|
||||
|
Loading…
Reference in New Issue
Block a user