Add FlaxCLIPTextModelWithProjection (#25254)

* Add FlaxClipTextModelWithProjection

This is necessary to support the Flax port of Stable Diffusion XL: fb6d705fb5/text_encoder_2/config.json (L3). A usage sketch follows below.

Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>

* Use FlaxCLIPTextModelOutput

* Run `make fix-copies` again

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Use `return_dict` for consistency with other uses.

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Fix docstring example.

* Add new model to FlaxCLIPTextModelTest

* Add to IGNORE_NON_AUTO_CONFIGURED list

* Fix naming convention.

---------

Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
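For context, a minimal sketch of how the new class might be used for the second SDXL text encoder. The repo id, the `tokenizer_2`/`text_encoder_2` subfolders, and the need for `from_pt=True` are assumptions based on the linked config, not part of this commit:

```python
from transformers import CLIPTokenizer, FlaxCLIPTextModelWithProjection

# Assumed SDXL layout: the second text encoder lives in the `text_encoder_2` subfolder
# and may only ship PyTorch weights, hence `from_pt=True`.
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"  # illustrative checkpoint
tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer_2")
model = FlaxCLIPTextModelWithProjection.from_pretrained(
    repo_id, subfolder="text_encoder_2", from_pt=True
)

inputs = tokenizer(["a photo of an astronaut"], padding=True, return_tensors="np")
outputs = model(**inputs)
print(outputs.text_embeds.shape)  # (1, config.projection_dim)
```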
Pedro Cuenca 2023-08-25 10:58:14 +02:00 committed by GitHub
parent 8968fface4
commit cb8e3ee25f
7 changed files with 126 additions and 2 deletions

View File

@@ -184,6 +184,11 @@ The resource should ideally demonstrate something new instead of duplicating an
[[autodoc]] FlaxCLIPTextModel
- __call__
## FlaxCLIPTextModelWithProjection
[[autodoc]] FlaxCLIPTextModelWithProjection
- __call__
## FlaxCLIPVisionModel
[[autodoc]] FlaxCLIPVisionModel

View File

@@ -3965,6 +3965,7 @@ else:
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPTextModelWithProjection",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
@@ -7388,6 +7389,7 @@ if TYPE_CHECKING:
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextModelWithProjection,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,

View File

@@ -94,6 +94,7 @@ else:
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPTextModelWithProjection",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
@@ -167,6 +168,7 @@ if TYPE_CHECKING:
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextModelWithProjection,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,

View File

@@ -155,6 +155,36 @@ CLIP_INPUTS_DOCSTRING = r"""
"""
@flax.struct.dataclass
class FlaxCLIPTextModelOutput(ModelOutput):
"""
Base class for text model's outputs that also contains a pooling of the last hidden states.
Args:
text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of
[`FlaxCLIPTextModel`].
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
text_embeds: jnp.ndarray = None
last_hidden_state: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxCLIPOutput(ModelOutput):
"""
@@ -1007,6 +1037,78 @@ append_replace_return_docstrings(
)
class FlaxCLIPTextModelWithProjectionModule(nn.Module):
config: CLIPTextConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
self.text_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = text_outputs[1]
text_embeds = self.text_projection(pooled_output)
if not return_dict:
return (text_embeds, text_outputs[0]) + text_outputs[2:]
return FlaxCLIPTextModelOutput(
text_embeds=text_embeds,
last_hidden_state=text_outputs.last_hidden_state,
hidden_states=text_outputs.hidden_states,
attentions=text_outputs.attentions,
)
class FlaxCLIPTextModelWithProjection(FlaxCLIPTextPreTrainedModel):
module_class = FlaxCLIPTextModelWithProjectionModule
FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING = """
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection
>>> model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
>>> outputs = model(**inputs)
>>> text_embeds = outputs.text_embeds
```
"""
overwrite_call_docstring(
FlaxCLIPTextModelWithProjection, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING
)
append_replace_return_docstrings(
FlaxCLIPTextModelWithProjection, output_type=FlaxCLIPTextModelOutput, config_class=CLIPTextConfig
)
class FlaxCLIPVisionModule(nn.Module):
config: CLIPVisionConfig
dtype: jnp.dtype = jnp.float32

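The `__call__` above mirrors the other CLIP heads: the pooled output is passed through the bias-free projection, and either a `FlaxCLIPTextModelOutput` or a plain tuple is returned depending on `return_dict`. A minimal sketch of both output styles, reusing the checkpoint from the docstring example:

```python
import numpy as np
from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection

model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="np")

# Default: a FlaxCLIPTextModelOutput dataclass with named fields.
outputs = model(**inputs)
text_embeds = outputs.text_embeds        # (batch_size, projection_dim)
last_hidden = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

# return_dict=False: the `(text_embeds, last_hidden_state, ...)` tuple branch above.
as_tuple = model(**inputs, return_dict=False)
assert np.allclose(np.asarray(text_embeds), np.asarray(as_tuple[0]))
```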
View File

@@ -562,6 +562,13 @@ class FlaxCLIPTextModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxCLIPTextModelWithProjection(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxCLIPTextPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]

View File

@@ -19,7 +19,12 @@ if is_flax_available():
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.clip.modeling_flax_clip import FlaxCLIPModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from transformers.models.clip.modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPTextModel,
FlaxCLIPTextModelWithProjection,
FlaxCLIPVisionModel,
)
if is_torch_available():
import torch
@@ -315,7 +320,7 @@ class FlaxCLIPTextModelTester:
@require_flax
class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxCLIPTextModel,) if is_flax_available() else ()
all_model_classes = (FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxCLIPTextModelTester(self)

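Adding the new class to `all_model_classes` runs it through the shared Flax tester. Illustrative only, not part of this commit: the kind of shape check that coverage implies, using a small hypothetical config:

```python
import unittest

import numpy as np

from transformers import CLIPTextConfig
from transformers.models.clip.modeling_flax_clip import FlaxCLIPTextModelWithProjection


class FlaxCLIPTextProjectionShapeTest(unittest.TestCase):
    def test_text_embeds_shape(self):
        # Tiny config so random initialization is cheap; values are illustrative.
        config = CLIPTextConfig(
            vocab_size=99,
            hidden_size=32,
            intermediate_size=37,
            num_hidden_layers=2,
            num_attention_heads=4,
            projection_dim=16,
        )
        model = FlaxCLIPTextModelWithProjection(config)
        input_ids = np.arange(8, dtype="i4").reshape(1, 8) % config.vocab_size
        attention_mask = np.ones_like(input_ids)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # The projection maps the pooled output to `projection_dim`.
        self.assertEqual(outputs.text_embeds.shape, (1, config.projection_dim))
```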
View File

@@ -205,6 +205,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"TFGroupViTTextModel",
"TFGroupViTVisionModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextModelWithProjection",
"FlaxCLIPVisionModel",
"FlaxWav2Vec2ForCTC",
"DetrForSegmentation",