Mirror of https://github.com/huggingface/transformers.git
Clamping hidden state values to allow FP16 (#19229)
* Clamping hidden state values to allow FP16
* Reformatting
* Adding missing if condition
* Update src/transformers/models/longt5/modeling_longt5.py
* Update src/transformers/models/longt5/modeling_longt5.py
* Update src/transformers/models/longt5/modeling_longt5.py
* Formatting file

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
parent 587d84b178
commit 971da2e6ec
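Float16 can only represent magnitudes up to 65504, so any intermediate hidden state that grows past that limit overflows to inf, and later operations then turn it into NaN. The diff below applies the same guard at three points in LongT5Block.forward: after self-attention, after cross-attention, and after the feed-forward layer. A minimal, self-contained sketch of the pattern (the helper name clamp_fp16_inf is mine, not part of the PR):

```python
import torch

def clamp_fp16_inf(hidden_states: torch.Tensor) -> torch.Tensor:
    # Same guard as in the diff: only act on fp16 tensors that actually
    # contain inf, and clamp slightly below the largest representable value.
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states

# 70000 is above the fp16 maximum of 65504, so it overflows to inf
x = torch.tensor([70000.0, -70000.0, 1.0], dtype=torch.float16)
print(torch.isinf(x).any())                   # tensor(True)
print(torch.isinf(clamp_fp16_inf(x)).any())   # tensor(False)
```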
@@ -1199,6 +1199,11 @@ class LongT5Block(nn.Module):
         hidden_states, present_key_value_state = self_attention_outputs[:2]
         attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights
 
+        # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
+        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
         do_cross_attention = self.is_decoder and encoder_hidden_states is not None
         if do_cross_attention:
             # the actual query length is unknown for cross attention
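The guard checks the dtype because only float16 has this problem: its maximum is 65504, while bfloat16 and float32 top out around 3.4e38 and do not overflow here. A quick way to see the limits that motivate the check (a sketch, not part of the diff):

```python
import torch

# dtype maxima behind the `hidden_states.dtype == torch.float16` check above
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    print(dtype, torch.finfo(dtype).max)
# torch.float16 65504.0
# torch.bfloat16 3.3895313892515355e+38
# torch.float32 3.4028234663852886e+38
```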
@@ -1221,6 +1226,11 @@ class LongT5Block(nn.Module):
             )
             hidden_states = cross_attention_outputs[0]
 
+            # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
+            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
+                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
             # Combine self attn and cross attn key value states
             if present_key_value_state is not None:
                 present_key_value_state = present_key_value_state + cross_attention_outputs[1]
@@ -1231,6 +1241,11 @@ class LongT5Block(nn.Module):
         # Apply Feed Forward layer
         hidden_states = self.layer[-1](hidden_states)
 
+        # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
+        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
         outputs = (hidden_states,)
 
         if use_cache:
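With the clamping in place, LongT5 checkpoints can be run for inference in half precision without hidden states blowing up to inf. A hedged usage sketch, assuming the google/long-t5-tglobal-base checkpoint and a CUDA device; the prompt and generation settings are illustrative only, not prescribed by the PR:

```python
import torch
from transformers import AutoTokenizer, LongT5ForConditionalGeneration

model_name = "google/long-t5-tglobal-base"  # illustrative checkpoint choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LongT5ForConditionalGeneration.from_pretrained(
    model_name, torch_dtype=torch.float16
).to("cuda")

text = "summarize: " + "The quick brown fox jumps over the lazy dog. " * 100
inputs = tokenizer(text, return_tensors="pt").to("cuda")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```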