Make CLIP model could use new added tokens with meaningful pooling (#24777)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-07-17 20:35:20 +02:00 committed by GitHub
parent d0154015f7
commit eeaa9c016a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 142 additions and 44 deletions

View File

@ -701,6 +701,9 @@ class CLIPTextTransformer(nn.Module):
self.encoder = CLIPEncoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
# For `pooled_output` computation
self.eos_token_id = config.eos_token_id
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
def forward(
@ -750,13 +753,26 @@ class CLIPTextTransformer(nn.Module):
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
if self.eos_token_id == 2:
# The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
# ------------------------------------------------------------
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
else:
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
# We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
(input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
.int()
.argmax(dim=-1),
]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]

View File

@ -487,6 +487,9 @@ class FlaxCLIPTextTransformer(nn.Module):
self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
# For `pooled_output` computation
self.eos_token_id = self.config.eos_token_id
def __call__(
self,
input_ids,
@ -517,9 +520,18 @@ class FlaxCLIPTextTransformer(nn.Module):
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the EOS embedding (eos_token_id is the highest number in each sequence)
pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
if self.eos_token_id == 2:
# The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
# ------------------------------------------------------------
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the EOS embedding (eos_token_id is the highest number in each sequence)
pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
else:
# (no need to cast from bool to int after comparing to `eos_token_id`)
pooled_output = last_hidden_state[
jnp.arange(last_hidden_state.shape[0]), (input_ids == self.eos_token_id).argmax(axis=-1)
]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]

View File

@ -494,6 +494,9 @@ class TFCLIPTextTransformer(tf.keras.layers.Layer):
epsilon=config.layer_norm_eps, name="final_layer_norm"
)
# For `pooled_output` computation
self.eos_token_id = config.eos_token_id
def call(
self,
input_ids: TFModelInputType,
@ -530,14 +533,30 @@ class TFCLIPTextTransformer(tf.keras.layers.Layer):
sequence_output = encoder_outputs[0]
sequence_output = self.final_layer_norm(inputs=sequence_output)
# text_embeds.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
pooled_output = tf.gather_nd(
params=sequence_output,
indices=tf.stack(
values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
),
)
if self.eos_token_id == 2:
# The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
# ------------------------------------------------------------
# text_embeds.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
pooled_output = tf.gather_nd(
params=sequence_output,
indices=tf.stack(
values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
),
)
else:
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
pooled_output = tf.gather_nd(
params=sequence_output,
indices=tf.stack(
values=(
tf.range(input_shape[0], dtype=tf.int64),
tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1),
),
axis=1,
),
)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

View File

@ -97,8 +97,8 @@ class CLIPSegTextConfig(PretrainedConfig):
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
bos_token_id=49406,
eos_token_id=49407,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

View File

@ -712,6 +712,9 @@ class CLIPSegTextTransformer(nn.Module):
self.encoder = CLIPSegEncoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
# For `pooled_output` computation
self.eos_token_id = config.eos_token_id
@add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegTextConfig)
# Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer.forward with clip->clipseg, CLIP->CLIPSeg
@ -762,13 +765,26 @@ class CLIPSegTextTransformer(nn.Module):
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
if self.eos_token_id == 2:
# The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
# A CLIPSeg model with such `eos_token_id` in the config can't work correctly with extra new tokens added
# ------------------------------------------------------------
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
else:
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
# We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
(input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
.int()
.argmax(dim=-1),
]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]

View File

@ -106,8 +106,8 @@ class GroupViTTextConfig(PretrainedConfig):
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
bos_token_id=49406,
eos_token_id=49407,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

View File

@ -1095,6 +1095,9 @@ class GroupViTTextTransformer(nn.Module):
self.encoder = GroupViTTextEncoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
# For `pooled_output` computation
self.eos_token_id = config.eos_token_id
@add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTTextConfig)
def forward(
@ -1144,13 +1147,26 @@ class GroupViTTextTransformer(nn.Module):
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
if self.eos_token_id == 2:
# The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
# ------------------------------------------------------------
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
else:
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
# We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
(input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
.int()
.argmax(dim=-1),
]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]

View File

@ -1002,6 +1002,9 @@ class TFGroupViTTextTransformer(tf.keras.layers.Layer):
epsilon=config.layer_norm_eps, name="final_layer_norm"
)
# For `pooled_output` computation
self.eos_token_id = config.eos_token_id
def call(
self,
input_ids: TFModelInputType,
@ -1038,14 +1041,30 @@ class TFGroupViTTextTransformer(tf.keras.layers.Layer):
sequence_output = encoder_outputs[0]
sequence_output = self.final_layer_norm(inputs=sequence_output)
# text_embeds.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
pooled_output = tf.gather_nd(
params=sequence_output,
indices=tf.stack(
values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
),
)
if self.eos_token_id == 2:
# The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
# ------------------------------------------------------------
# text_embeds.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
pooled_output = tf.gather_nd(
params=sequence_output,
indices=tf.stack(
values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
),
)
else:
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
pooled_output = tf.gather_nd(
params=sequence_output,
indices=tf.stack(
values=(
tf.range(input_shape[0], dtype=tf.int64),
tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1),
),
axis=1,
),
)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]