mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add FlaxCLIPTextModelWithProjection (#25254)
* Add FlaxClipTextModelWithProjection
This is necessary to support the Flax port of Stable Diffusion XL: fb6d705fb5/text_encoder_2/config.json (L3)
Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
* Use FlaxCLIPTextModelOutput
* make fix-copies again
* Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Use `return_dict` for consistency with other uses.
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Fix docstring example.
* Add new model to FlaxCLIPTextModelTest
* Add to IGNORE_NON_AUTO_CONFIGURED list
* Fix naming convention.
---------
Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
parent
8968fface4
commit
cb8e3ee25f
@ -184,6 +184,11 @@ The resource should ideally demonstrate something new instead of duplicating an
|
||||
[[autodoc]] FlaxCLIPTextModel
|
||||
- __call__
|
||||
|
||||
## FlaxCLIPTextModelWithProjection
|
||||
|
||||
[[autodoc]] FlaxCLIPTextModelWithProjection
|
||||
- __call__
|
||||
|
||||
## FlaxCLIPVisionModel
|
||||
|
||||
[[autodoc]] FlaxCLIPVisionModel
|
||||
|
@ -3965,6 +3965,7 @@ else:
|
||||
"FlaxCLIPPreTrainedModel",
|
||||
"FlaxCLIPTextModel",
|
||||
"FlaxCLIPTextPreTrainedModel",
|
||||
"FlaxCLIPTextModelWithProjection",
|
||||
"FlaxCLIPVisionModel",
|
||||
"FlaxCLIPVisionPreTrainedModel",
|
||||
]
|
||||
@ -7388,6 +7389,7 @@ if TYPE_CHECKING:
|
||||
FlaxCLIPModel,
|
||||
FlaxCLIPPreTrainedModel,
|
||||
FlaxCLIPTextModel,
|
||||
FlaxCLIPTextModelWithProjection,
|
||||
FlaxCLIPTextPreTrainedModel,
|
||||
FlaxCLIPVisionModel,
|
||||
FlaxCLIPVisionPreTrainedModel,
|
||||
|
@ -94,6 +94,7 @@ else:
|
||||
"FlaxCLIPPreTrainedModel",
|
||||
"FlaxCLIPTextModel",
|
||||
"FlaxCLIPTextPreTrainedModel",
|
||||
"FlaxCLIPTextModelWithProjection",
|
||||
"FlaxCLIPVisionModel",
|
||||
"FlaxCLIPVisionPreTrainedModel",
|
||||
]
|
||||
@ -167,6 +168,7 @@ if TYPE_CHECKING:
|
||||
FlaxCLIPModel,
|
||||
FlaxCLIPPreTrainedModel,
|
||||
FlaxCLIPTextModel,
|
||||
FlaxCLIPTextModelWithProjection,
|
||||
FlaxCLIPTextPreTrainedModel,
|
||||
FlaxCLIPVisionModel,
|
||||
FlaxCLIPVisionPreTrainedModel,
|
||||
|
@ -155,6 +155,36 @@ CLIP_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxCLIPTextModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for text model's outputs that also contains a pooling of the last hidden states.
|
||||
|
||||
Args:
|
||||
text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim`):
|
||||
The text embeddings obtained by applying the projection layer to the pooled output of
|
||||
[`FlaxCLIPTextModel`].
|
||||
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
|
||||
`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
text_embeds: jnp.ndarray = None
|
||||
last_hidden_state: jnp.ndarray = None
|
||||
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxCLIPOutput(ModelOutput):
|
||||
"""
|
||||
@ -1007,6 +1037,78 @@ append_replace_return_docstrings(
|
||||
)
|
||||
|
||||
|
||||
class FlaxCLIPTextModelWithProjectionModule(nn.Module):
|
||||
config: CLIPTextConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
|
||||
self.text_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = text_outputs[1]
|
||||
text_embeds = self.text_projection(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (text_embeds, text_outputs[0]) + text_outputs[2:]
|
||||
|
||||
return FlaxCLIPTextModelOutput(
|
||||
text_embeds=text_embeds,
|
||||
last_hidden_state=text_outputs.last_hidden_state,
|
||||
hidden_states=text_outputs.hidden_states,
|
||||
attentions=text_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class FlaxCLIPTextModelWithProjection(FlaxCLIPTextPreTrainedModel):
|
||||
module_class = FlaxCLIPTextModelWithProjectionModule
|
||||
|
||||
|
||||
FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING = """
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection
|
||||
|
||||
>>> model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> text_embeds = outputs.text_embeds
|
||||
```
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(
|
||||
FlaxCLIPTextModelWithProjection, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING
|
||||
)
|
||||
append_replace_return_docstrings(
|
||||
FlaxCLIPTextModelWithProjection, output_type=FlaxCLIPTextModelOutput, config_class=CLIPTextConfig
|
||||
)
|
||||
|
||||
|
||||
class FlaxCLIPVisionModule(nn.Module):
|
||||
config: CLIPVisionConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
@ -562,6 +562,13 @@ class FlaxCLIPTextModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxCLIPTextModelWithProjection(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxCLIPTextPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
@ -19,7 +19,12 @@ if is_flax_available():
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
from transformers.models.clip.modeling_flax_clip import FlaxCLIPModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
|
||||
from transformers.models.clip.modeling_flax_clip import (
|
||||
FlaxCLIPModel,
|
||||
FlaxCLIPTextModel,
|
||||
FlaxCLIPTextModelWithProjection,
|
||||
FlaxCLIPVisionModel,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@ -315,7 +320,7 @@ class FlaxCLIPTextModelTester:
|
||||
|
||||
@require_flax
|
||||
class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxCLIPTextModel,) if is_flax_available() else ()
|
||||
all_model_classes = (FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxCLIPTextModelTester(self)
|
||||
|
@ -205,6 +205,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"TFGroupViTTextModel",
|
||||
"TFGroupViTVisionModel",
|
||||
"FlaxCLIPTextModel",
|
||||
"FlaxCLIPTextModelWithProjection",
|
||||
"FlaxCLIPVisionModel",
|
||||
"FlaxWav2Vec2ForCTC",
|
||||
"DetrForSegmentation",
|
||||
|
Loading…
Reference in New Issue
Block a user