Make CLIP models able to use newly added tokens with meaningful pooling (#24777)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent: d0154015f7
commit: eeaa9c016a
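Summary of the change: the text towers of CLIP, CLIPSeg, and GroupViT used to pool on the hidden state at `input_ids.argmax(-1)`, which only works because `<|endoftext|>` (id 49407) is the highest id in the original vocabulary; any token added on top of that vocabulary gets a larger id and silently steals the pooling position. With this commit, whenever the config carries the corrected `eos_token_id` from PR #24773, pooling is taken at the first occurrence of that id instead. A minimal usage sketch, not part of this diff — the checkpoint name and the new token are illustrative, and it assumes the loaded config already has `eos_token_id=49407`:

    # Sketch only: assumes the checkpoint's config carries the corrected eos_token_id (49407)
    # from PR #24773; the checkpoint name and new token below are illustrative.
    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    # A newly added token gets an id above 49407, the original <|endoftext|> id.
    tokenizer.add_tokens(["<my-new-concept>"])  # hypothetical token
    model.resize_token_embeddings(len(tokenizer))

    inputs = tokenizer(["a photo of <my-new-concept>"], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # pooler_output is now taken at the EOS position, not at the new highest-id token.
    print(outputs.pooler_output.shape)  # (1, hidden_size)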
src/transformers/models/clip/modeling_clip.py

@@ -701,6 +701,9 @@ class CLIPTextTransformer(nn.Module):
         self.encoder = CLIPEncoder(config)
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
     def forward(
@@ -750,13 +753,26 @@
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)
 
-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-        pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
-            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
-        ]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+            ]
+        else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                # We need to get the first position of `eos_token_id` (`pad_token_id` might be equal to `eos_token_id`)
+                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
+                .int()
+                .argmax(dim=-1),
+            ]
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
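To see why the old pooling breaks and what the `else` branch above computes, here is a toy PyTorch sketch with made-up ids (49407 stands in for `<|endoftext|>`, 49408 for a newly added token):

    # Toy sketch (illustrative ids, not from this diff): why argmax-pooling breaks.
    import torch

    eos_token_id = 49407
    input_ids = torch.tensor([[49406, 320, 1125, 49408, 49407]])  # 49408 = newly added token

    # Legacy branch (config still says eos_token_id == 2): the highest id wins,
    # which is the new token at position 3, not the EOS token.
    legacy_idx = input_ids.to(torch.int).argmax(dim=-1)          # tensor([3])

    # New branch: first position whose id equals eos_token_id.
    new_idx = (input_ids == eos_token_id).int().argmax(dim=-1)   # tensor([4])

    print(legacy_idx, new_idx)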
src/transformers/models/clip/modeling_flax_clip.py

@@ -487,6 +487,9 @@ class FlaxCLIPTextTransformer(nn.Module):
         self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
         self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
 
+        # For `pooled_output` computation
+        self.eos_token_id = self.config.eos_token_id
+
     def __call__(
         self,
         input_ids,
@@ -517,9 +520,18 @@
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)
 
-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the EOS embedding (eos_token_id is the highest number in each sequence)
-        pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the EOS embedding (eos_token_id is the highest number in each sequence)
+            pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
+        else:
+            # (no need to cast from bool to int after comparing to `eos_token_id`)
+            pooled_output = last_hidden_state[
+                jnp.arange(last_hidden_state.shape[0]), (input_ids == self.eos_token_id).argmax(axis=-1)
+            ]
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
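The Flax branch skips the `.int()` cast used in the PyTorch version because `jnp.argmax` follows NumPy semantics on boolean arrays and returns the index of the first `True`. A tiny sketch with illustrative ids:

    # Sketch (illustrative ids): jnp.argmax on a boolean mask returns the first True,
    # so no explicit int cast is needed before picking the first EOS position.
    import jax.numpy as jnp

    eos_token_id = 49407
    input_ids = jnp.array([[49406, 320, 1125, 49407, 49407]])

    first_eos = (input_ids == eos_token_id).argmax(axis=-1)  # Array([3])
    print(first_eos)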
src/transformers/models/clip/modeling_tf_clip.py

@@ -494,6 +494,9 @@ class TFCLIPTextTransformer(tf.keras.layers.Layer):
             epsilon=config.layer_norm_eps, name="final_layer_norm"
         )
 
+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     def call(
         self,
         input_ids: TFModelInputType,
@@ -530,14 +533,30 @@
         sequence_output = encoder_outputs[0]
         sequence_output = self.final_layer_norm(inputs=sequence_output)
 
-        # text_embeds.shape = [batch_size, n_ctx, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        pooled_output = tf.gather_nd(
-            params=sequence_output,
-            indices=tf.stack(
-                values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
-            ),
-        )
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, n_ctx, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
+                ),
+            )
+        else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(
+                        tf.range(input_shape[0], dtype=tf.int64),
+                        tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1),
+                    ),
+                    axis=1,
+                ),
+            )
 
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
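TensorFlow has no multi-array fancy indexing, so the TF ports build explicit (batch, position) pairs with `tf.stack` and look them up with `tf.gather_nd`; the boolean mask is cast to `tf.int8` before `tf.math.argmax`. A standalone sketch of that indexing on toy data (not the model code itself):

    # Sketch (toy data): the (batch, position) lookup used by the TF ports.
    import tensorflow as tf

    eos_token_id = 49407
    input_ids = tf.constant([[49406, 320, 1125, 49408, 49407]])
    sequence_output = tf.random.normal((1, 5, 8))  # (batch, seq_len, hidden)

    # First EOS position per sequence; the bool mask is cast before argmax.
    eos_pos = tf.math.argmax(tf.cast(input_ids == eos_token_id, dtype=tf.int8), axis=-1)

    indices = tf.stack(values=(tf.range(1, dtype=tf.int64), eos_pos), axis=1)
    pooled_output = tf.gather_nd(params=sequence_output, indices=indices)
    print(pooled_output.shape)  # (1, 8)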
src/transformers/models/clipseg/configuration_clipseg.py

@@ -97,8 +97,8 @@ class CLIPSegTextConfig(PretrainedConfig):
         initializer_range=0.02,
         initializer_factor=1.0,
         pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
+        bos_token_id=49406,
+        eos_token_id=49407,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
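The new defaults match the ids the CLIP BPE tokenizer actually uses for its special tokens (`<|startoftext|>` = 49406, `<|endoftext|>` = 49407); checkpoints whose saved config still says `eos_token_id=2` keep taking the legacy argmax branch above. A quick check of the new defaults (illustrative, run with this change applied):

    # Sketch: the defaults now line up with the CLIP BPE special-token ids.
    from transformers import CLIPSegTextConfig

    config = CLIPSegTextConfig()
    print(config.bos_token_id, config.eos_token_id)  # 49406 49407 with this change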
src/transformers/models/clipseg/modeling_clipseg.py

@@ -712,6 +712,9 @@ class CLIPSegTextTransformer(nn.Module):
         self.encoder = CLIPSegEncoder(config)
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     @add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegTextConfig)
     # Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer.forward with clip->clipseg, CLIP->CLIPSeg
@@ -762,13 +765,26 @@
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)
 
-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-        pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
-            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
-        ]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIPSeg model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+            ]
+        else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                # We need to get the first position of `eos_token_id` (`pad_token_id` might be equal to `eos_token_id`)
+                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
+                .int()
+                .argmax(dim=-1),
+            ]
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
src/transformers/models/groupvit/configuration_groupvit.py

@@ -106,8 +106,8 @@ class GroupViTTextConfig(PretrainedConfig):
         initializer_range=0.02,
         initializer_factor=1.0,
         pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
+        bos_token_id=49406,
+        eos_token_id=49407,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
src/transformers/models/groupvit/modeling_groupvit.py

@@ -1095,6 +1095,9 @@ class GroupViTTextTransformer(nn.Module):
         self.encoder = GroupViTTextEncoder(config)
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTTextConfig)
     def forward(
@@ -1144,13 +1147,26 @@
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)
 
-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-        pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
-            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
-        ]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+            ]
+        else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                # We need to get the first position of `eos_token_id` (`pad_token_id` might be equal to `eos_token_id`)
+                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
+                .int()
+                .argmax(dim=-1),
+            ]
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
src/transformers/models/groupvit/modeling_tf_groupvit.py

@@ -1002,6 +1002,9 @@ class TFGroupViTTextTransformer(tf.keras.layers.Layer):
             epsilon=config.layer_norm_eps, name="final_layer_norm"
         )
 
+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     def call(
         self,
         input_ids: TFModelInputType,
@@ -1038,14 +1041,30 @@
         sequence_output = encoder_outputs[0]
         sequence_output = self.final_layer_norm(inputs=sequence_output)
 
-        # text_embeds.shape = [batch_size, n_ctx, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        pooled_output = tf.gather_nd(
-            params=sequence_output,
-            indices=tf.stack(
-                values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
-            ),
-        )
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, n_ctx, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
+                ),
+            )
+        else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(
+                        tf.range(input_shape[0], dtype=tf.int64),
+                        tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1),
+                    ),
+                    axis=1,
+                ),
+            )
 
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
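Backward-compatibility note: for inputs drawn only from the original vocabulary, the legacy argmax rule and the new first-EOS rule pick the same position, because `<|endoftext|>` is the highest id there; a toy check with illustrative ids:

    # Sketch (illustrative ids): both pooling rules agree when no newly added tokens are present.
    import torch

    eos_token_id = 49407
    input_ids = torch.tensor([[49406, 320, 1125, 49407]])

    legacy_idx = input_ids.to(torch.int).argmax(dim=-1)         # tensor([3])
    new_idx = (input_ids == eos_token_id).int().argmax(dim=-1)  # tensor([3])
    assert torch.equal(legacy_idx, new_idx)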