diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py
index 405fe27d49c..2297919cd14 100644
--- a/src/transformers/models/clip/modeling_clip.py
+++ b/src/transformers/models/clip/modeling_clip.py
@@ -701,6 +701,9 @@ class CLIPTextTransformer(nn.Module):
         self.encoder = CLIPEncoder(config)
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
     def forward(
@@ -750,13 +753,26 @@ class CLIPTextTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)

-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-        pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
-            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
-        ]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: Let's keep what was done here.
+            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+            ]
+        else:
+            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                # We need to get the first position of `eos_token_id` value (`pad_token_id` might be equal to `eos_token_id`)
+                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
+                .int()
+                .argmax(dim=-1),
+            ]

         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
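Note (not part of the patch): the legacy lookup above relies on the EOS token having the highest id in each sequence (49407, the top of the original CLIP vocabulary), so an `argmax` over the raw ids happens to land on it; as soon as extra tokens with larger ids are added, that assumption breaks. The new lookup instead builds a 0/1 mask against `eos_token_id` and takes its `argmax`, which returns the first matching position, which is what we want when the tokenizer pads with the EOS token. A minimal PyTorch sketch with made-up sequences:

```python
import torch

eos_token_id = 49407   # id of <|endoftext|> in the original CLIP vocabulary
new_token_id = 49500   # hypothetical extra token added on top of that vocabulary

input_ids = torch.tensor(
    [
        [49406, 320, 1125, 49407, 49407],           # regular sequence, padded with the EOS id
        [49406, new_token_id, 1125, 49407, 49407],  # sequence containing the added token
    ]
)

# Legacy lookup: only correct while EOS is the largest id in every sequence.
legacy_pos = input_ids.argmax(dim=-1)                        # tensor([3, 1]) -> wrong for row 1

# New lookup: argmax over a 0/1 mask returns the first maximal element,
# i.e. the first EOS position, even if padding repeats the same id afterwards.
eos_pos = (input_ids == eos_token_id).int().argmax(dim=-1)   # tensor([3, 3])
```

The extra `.to(dtype=torch.int, ...)` in the patch itself is kept only for ONNX export compatibility, as the original comment explains.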
diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py
index cb8ee4e7c9a..750e5b05485 100644
--- a/src/transformers/models/clip/modeling_flax_clip.py
+++ b/src/transformers/models/clip/modeling_flax_clip.py
@@ -487,6 +487,9 @@ class FlaxCLIPTextTransformer(nn.Module):
         self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
         self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

+        # For `pooled_output` computation
+        self.eos_token_id = self.config.eos_token_id
+
     def __call__(
         self,
         input_ids,
@@ -517,9 +520,18 @@ class FlaxCLIPTextTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)

-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the EOS embedding (eos_token_id is the highest number in each sequence)
-        pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: Let's keep what was done here.
+            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the EOS embedding (eos_token_id is the highest number in each sequence)
+            pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
+        else:
+            # (no need to cast from bool to int after comparing to `eos_token_id`)
+            pooled_output = last_hidden_state[
+                jnp.arange(last_hidden_state.shape[0]), (input_ids == self.eos_token_id).argmax(axis=-1)
+            ]

         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
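Note (not part of the patch): the Flax branch can skip the explicit integer cast because, following NumPy semantics, `jnp.argmax` accepts a boolean array and returns the index of its first `True`. A quick sketch with the same made-up ids as above:

```python
import jax.numpy as jnp

eos_token_id = 49407
input_ids = jnp.array([[49406, 320, 1125, 49407, 49407]])

# No cast needed: argmax over the boolean mask gives the first EOS position.
eos_pos = (input_ids == eos_token_id).argmax(axis=-1)  # Array([3], dtype=int32)
```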
diff --git a/src/transformers/models/clip/modeling_tf_clip.py b/src/transformers/models/clip/modeling_tf_clip.py
index 3452deba9cd..335b1f7da8e 100644
--- a/src/transformers/models/clip/modeling_tf_clip.py
+++ b/src/transformers/models/clip/modeling_tf_clip.py
@@ -494,6 +494,9 @@ class TFCLIPTextTransformer(tf.keras.layers.Layer):
             epsilon=config.layer_norm_eps, name="final_layer_norm"
         )

+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     def call(
         self,
         input_ids: TFModelInputType,
@@ -530,14 +533,30 @@ class TFCLIPTextTransformer(tf.keras.layers.Layer):
         sequence_output = encoder_outputs[0]
         sequence_output = self.final_layer_norm(inputs=sequence_output)

-        # text_embeds.shape = [batch_size, n_ctx, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        pooled_output = tf.gather_nd(
-            params=sequence_output,
-            indices=tf.stack(
-                values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
-            ),
-        )
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: Let's keep what was done here.
+            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, n_ctx, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
+                ),
+            )
+        else:
+            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(
+                        tf.range(input_shape[0], dtype=tf.int64),
+                        tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1),
+                    ),
+                    axis=1,
+                ),
+            )

         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
diff --git a/src/transformers/models/clipseg/configuration_clipseg.py b/src/transformers/models/clipseg/configuration_clipseg.py
index 2f05492a755..6d4147f8206 100644
--- a/src/transformers/models/clipseg/configuration_clipseg.py
+++ b/src/transformers/models/clipseg/configuration_clipseg.py
@@ -97,8 +97,8 @@ class CLIPSegTextConfig(PretrainedConfig):
         initializer_range=0.02,
         initializer_factor=1.0,
         pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
+        bos_token_id=49406,
+        eos_token_id=49407,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
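Two notes on the hunks above (not part of the patch): `tf.math.argmax` does not accept boolean tensors, which is why the new branch casts the mask to `tf.int8` first; and the `CLIPSegTextConfig` defaults now match the ids CLIP's BPE vocabulary actually uses for `<|startoftext|>` (49406) and `<|endoftext|>` (49407) instead of the earlier 0/2. A minimal TensorFlow sketch of the new lookup, again with made-up ids:

```python
import tensorflow as tf

eos_token_id = 49407
input_ids = tf.constant([[49406, 320, 1125, 49407, 49407]], dtype=tf.int64)

# Cast the boolean mask before argmax; tf.math.argmax rejects bool inputs.
mask = tf.cast(input_ids == eos_token_id, dtype=tf.int8)
eos_pos = tf.math.argmax(mask, axis=-1)  # [3] -> first EOS position
```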
diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py
index 5f49160f4d3..4cab4425f18 100644
--- a/src/transformers/models/clipseg/modeling_clipseg.py
+++ b/src/transformers/models/clipseg/modeling_clipseg.py
@@ -712,6 +712,9 @@ class CLIPSegTextTransformer(nn.Module):
         self.encoder = CLIPSegEncoder(config)
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     @add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegTextConfig)
     # Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer.forward with clip->clipseg, CLIP->CLIPSeg
@@ -762,13 +765,26 @@ class CLIPSegTextTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)

-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-        pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
-            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
-        ]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: Let's keep what was done here.
+            # A CLIPSeg model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+            ]
+        else:
+            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                # We need to get the first position of `eos_token_id` value (`pad_token_id` might be equal to `eos_token_id`)
+                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
+                .int()
+                .argmax(dim=-1),
+            ]

         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
diff --git a/src/transformers/models/groupvit/configuration_groupvit.py b/src/transformers/models/groupvit/configuration_groupvit.py
index 867377a863e..4d10a2dbb50 100644
--- a/src/transformers/models/groupvit/configuration_groupvit.py
+++ b/src/transformers/models/groupvit/configuration_groupvit.py
@@ -106,8 +106,8 @@ class GroupViTTextConfig(PretrainedConfig):
         initializer_range=0.02,
         initializer_factor=1.0,
         pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
+        bos_token_id=49406,
+        eos_token_id=49407,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
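Note (not part of the patch): which branch runs is decided entirely by the `eos_token_id` the text transformer was constructed with (see the `self.eos_token_id = config.eos_token_id` lines), so a config that still explicitly stores `eos_token_id=2` keeps the legacy argmax-over-ids behaviour, while freshly created configs pick up the new defaults. A small illustration using `CLIPSegTextConfig`; the other updated text configs behave the same way:

```python
from transformers import CLIPSegTextConfig, CLIPSegTextModel

new_config = CLIPSegTextConfig()                # eos_token_id defaults to 49407 -> first-EOS lookup
old_config = CLIPSegTextConfig(eos_token_id=2)  # mimics a config saved before this change -> legacy lookup

new_model = CLIPSegTextModel(new_config)
old_model = CLIPSegTextModel(old_config)
print(new_model.config.eos_token_id, old_model.config.eos_token_id)  # 49407 2
```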
diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py
index 89bfb8d005d..59ff60ed765 100644
--- a/src/transformers/models/groupvit/modeling_groupvit.py
+++ b/src/transformers/models/groupvit/modeling_groupvit.py
@@ -1095,6 +1095,9 @@ class GroupViTTextTransformer(nn.Module):
         self.encoder = GroupViTTextEncoder(config)
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTTextConfig)
     def forward(
@@ -1144,13 +1147,26 @@ class GroupViTTextTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)

-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-        pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
-            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
-        ]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: Let's keep what was done here.
+            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+            ]
+        else:
+            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                # We need to get the first position of `eos_token_id` value (`pad_token_id` might be equal to `eos_token_id`)
+                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
+                .int()
+                .argmax(dim=-1),
+            ]

         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
diff --git a/src/transformers/models/groupvit/modeling_tf_groupvit.py b/src/transformers/models/groupvit/modeling_tf_groupvit.py
index 2c3297a8f8b..027117bdce2 100644
--- a/src/transformers/models/groupvit/modeling_tf_groupvit.py
+++ b/src/transformers/models/groupvit/modeling_tf_groupvit.py
@@ -1002,6 +1002,9 @@ class TFGroupViTTextTransformer(tf.keras.layers.Layer):
             epsilon=config.layer_norm_eps, name="final_layer_norm"
         )

+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     def call(
         self,
         input_ids: TFModelInputType,
@@ -1038,14 +1041,30 @@ class TFGroupViTTextTransformer(tf.keras.layers.Layer):
         sequence_output = encoder_outputs[0]
         sequence_output = self.final_layer_norm(inputs=sequence_output)

-        # text_embeds.shape = [batch_size, n_ctx, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        pooled_output = tf.gather_nd(
-            params=sequence_output,
-            indices=tf.stack(
-                values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
-            ),
-        )
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: Let's keep what was done here.
+            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, n_ctx, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
+                ),
+            )
+        else:
+            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(
+                        tf.range(input_shape[0], dtype=tf.int64),
+                        tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1),
+                    ),
+                    axis=1,
+                ),
+            )

         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
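Note (not part of the patch): a rough end-to-end way to exercise the fix, sketched against the standard `openai/clip-vit-base-patch32` checkpoint. Which branch runs depends on the `eos_token_id` stored in that checkpoint's config, so it is overridden explicitly below to force the first-EOS lookup; the added token name is purely illustrative.

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

name = "openai/clip-vit-base-patch32"
tokenizer = CLIPTokenizer.from_pretrained(name)
# Override eos_token_id so the new lookup path is taken regardless of what the hub config stores.
model = CLIPTextModel.from_pretrained(name, eos_token_id=49407)

# Add a token whose id (49408) is larger than the EOS id (49407), then resize the embeddings.
tokenizer.add_tokens(["<my-new-token>"])
model.resize_token_embeddings(len(tokenizer))

inputs = tokenizer(["a photo of a <my-new-token>"], padding=True, return_tensors="pt")
with torch.no_grad():
    pooled = model(**inputs).pooler_output

# Before this change, the pooled vector would have been taken at the position of the
# added token (the highest id) instead of at the first EOS position.
print(pooled.shape)  # torch.Size([1, 512])
```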