Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
[InstructBlip] Fix int8/fp4 issues (#24888)
* fix dtype issue
* revert `.float()`
* fix copies
This commit is contained in:
parent 3ec10e6c76
commit a9e067a45c
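For context, the int8/fp4 paths in the title are the bitsandbytes quantized loading modes; the dtype casts below target that setup. A minimal loading sketch, not taken from this commit (the checkpoint id and both-models-at-once layout are illustrative only; requires a CUDA GPU and bitsandbytes):

```python
import torch
from transformers import BitsAndBytesConfig, InstructBlipForConditionalGeneration, InstructBlipProcessor

# Illustrative checkpoint; any InstructBlip checkpoint should behave the same way.
model_id = "Salesforce/instructblip-vicuna-7b"

# int8 path: linear weights in int8, remaining modules kept in fp16.
model_int8 = InstructBlipForConditionalGeneration.from_pretrained(
    model_id, load_in_8bit=True, torch_dtype=torch.float16
)

# fp4 path: 4-bit weights via a BitsAndBytesConfig.
fp4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="fp4")
model_fp4 = InstructBlipForConditionalGeneration.from_pretrained(
    model_id, quantization_config=fp4_config
)

processor = InstructBlipProcessor.from_pretrained(model_id)
```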
@@ -558,7 +558,6 @@ class InstructBlipVisionModel(InstructBlipPreTrainedModel):
         return self.embeddings
 
 
-# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerMultiHeadAttention with Blip2->InstructBlip
 class InstructBlipQFormerMultiHeadAttention(nn.Module):
     def __init__(self, config, is_cross_attention=False):
         super().__init__()
@@ -659,13 +658,14 @@ class InstructBlipQFormerMultiHeadAttention(nn.Module):
                 attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
 
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        attention_scores_dtype = attention_scores.dtype
 
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
             attention_scores = attention_scores + attention_mask
 
         # Normalize the attention scores to probabilities.
-        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+        attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)
 
         if is_cross_attention and self.save_attention:
             self.save_attention_map(attention_probs)
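The cast back to `attention_scores_dtype` matters because the mask addition (or other promotion along the way) can push the scores to fp32, while the value states of an int8/fp4-loaded model stay in fp16. A minimal sketch of the pattern, not the actual transformers code, assuming fp16 activations and an fp32 attention mask:

```python
import torch
from torch import nn

# Sketch only: fp16 scores plus an fp32 mask; type promotion upcasts the scores.
attention_scores = torch.randn(1, 12, 32, 32, dtype=torch.float16)
attention_mask = torch.zeros(1, 1, 1, 32, dtype=torch.float32)

attention_scores_dtype = attention_scores.dtype        # remember torch.float16
attention_scores = attention_scores + attention_mask   # promoted to torch.float32

# Without the .to(...), attention_probs would stay fp32 and the later matmul with
# the fp16 value states could fail with a Half/Float mismatch.
attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)
print(attention_probs.dtype)  # torch.float16
```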
@@ -1038,6 +1038,7 @@ class InstructBlipQFormerEmbeddings(nn.Module):
         else:
             embeddings = query_embeds
 
+        embeddings = embeddings.to(self.layernorm.weight.dtype)
         embeddings = self.layernorm(embeddings)
         embeddings = self.dropout(embeddings)
         return embeddings
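Casting the embeddings to the layernorm weight dtype serves the same purpose: with int8/fp4 loading the layernorm weights are typically kept in fp16, while the query embeddings can arrive in a different dtype. A minimal sketch, not the actual transformers code, assuming an fp16 layernorm and fp32 query embeddings:

```python
import torch
from torch import nn

# Sketch only: fp16 layernorm weights, fp32 query embeddings.
layernorm = nn.LayerNorm(768).to(torch.float16)
query_embeds = torch.randn(1, 32, 768, dtype=torch.float32)

# Casting first keeps input and weights in the same dtype; without it,
# layer_norm can raise a Half/Float mismatch on some backends.
embeddings = query_embeds.to(layernorm.weight.dtype)
embeddings = layernorm(embeddings)
print(embeddings.dtype)  # torch.float16
```

(On older PyTorch builds the fp16 layer_norm kernel may require a GPU.)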