Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)

minor fixes (#14026)

parent f5af873617
commit 84ad6af49a
src/transformers/models/clip/modeling_clip.py

@@ -15,6 +15,7 @@
 """ PyTorch CLIP model. """
 
 
+from dataclasses import dataclass
 from typing import Any, Optional, Tuple
 
 import torch
@@ -71,6 +72,7 @@ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
     return (caption_loss + image_loss) / 2.0
 
 
+@dataclass
 class CLIPOutput(ModelOutput):
     """
     Args:
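Side note (not part of the commit): declaring ModelOutput subclasses such as CLIPOutput as dataclasses is what turns their annotated fields into defaulted constructor arguments that can also be read back positionally. A minimal standalone sketch of that pattern, using only the standard library; ToyCLIPOutput and its to_tuple helper are illustrative stand-ins, not the real class:

# Illustrative stand-in for the dataclass-based output pattern; the real CLIPOutput
# additionally inherits from transformers' ModelOutput for dict/tuple-style access.
from dataclasses import dataclass, fields
from typing import Optional

import torch


@dataclass
class ToyCLIPOutput:
    loss: Optional[torch.FloatTensor] = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None

    def to_tuple(self):
        # Drop fields left at None, similar in spirit to ModelOutput.to_tuple().
        return tuple(getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None)


out = ToyCLIPOutput(logits_per_image=torch.ones(2, 2), logits_per_text=torch.ones(2, 2))
print(out.logits_per_image.shape)  # attribute access: torch.Size([2, 2])
print(len(out.to_tuple()))         # 2 -- loss stays None and is dropped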
@@ -297,10 +299,9 @@ class CLIPEncoderLayer(nn.Module):
     ):
         """
         Args:
-            hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape :obj:`(seq_len, batch, embed_dim)`
+            hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape :obj:`(batch, seq_len, embed_dim)`
             attention_mask (:obj:`torch.FloatTensor`): attention mask of size
                 :obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
                 :obj:`(config.encoder_attention_heads,)`.
             output_attentions (:obj:`bool`, `optional`):
                 Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
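For readers double-checking the corrected shape note: CLIP's encoder is batch-first, so every hidden state it produces has shape (batch, seq_len, embed_dim). A quick sanity-check sketch with a randomly initialized tiny text model; the config values below are arbitrary, and the snippet assumes a transformers version that exposes CLIPTextConfig and CLIPTextModel:

import torch
from transformers import CLIPTextConfig, CLIPTextModel

# Tiny random config so the check runs without downloading weights.
config = CLIPTextConfig(
    vocab_size=99,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=16,
)
model = CLIPTextModel(config).eval()

batch, seq_len = 2, 7
input_ids = torch.randint(0, config.vocab_size, (batch, seq_len))

with torch.no_grad():
    out = model(input_ids, output_hidden_states=True)

# Inputs and outputs of every encoder layer are batch-first, matching the fixed docstring.
for hidden in out.hidden_states:
    assert hidden.shape == (batch, seq_len, config.hidden_size)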
@@ -497,7 +498,6 @@ class CLIPEncoder(nn.Module):
 
     Args:
         config: CLIPConfig
-        embed_tokens (nn.Embedding): output embedding
     """
 
     def __init__(self, config: CLIPConfig):
@@ -517,7 +517,7 @@
     ):
         r"""
         Args:
-            inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
                 Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
                 representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
                 into associated vectors than the model's internal embedding lookup matrix.
tests/test_modeling_clip.py

@@ -102,7 +102,8 @@ class CLIPVisionModelTester:
         model = CLIPVisionModel(config=config)
         model.to(torch_device)
         model.eval()
-        result = model(pixel_values)
+        with torch.no_grad():
+            result = model(pixel_values)
         # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
         image_size = (self.image_size, self.image_size)
         patch_size = (self.patch_size, self.patch_size)
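The remaining test hunks apply the same change, so one hedged illustration of the pattern covers them all (a toy nn.Linear stands in for the CLIP models under test): wrapping inference-only forwards in torch.no_grad() stops autograd from recording the graph, which saves memory and makes it explicit that the test never backpropagates.

import torch
from torch import nn

model = nn.Linear(8, 4).eval()  # stand-in for the CLIPVisionModel/CLIPTextModel/CLIPModel under test
x = torch.randn(2, 8)

y_tracked = model(x)            # autograd records this op
with torch.no_grad():
    y_untracked = model(x)      # no graph is built

print(y_tracked.requires_grad)    # True
print(y_untracked.requires_grad)  # False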
@@ -350,8 +351,9 @@ class CLIPTextModelTester:
         model = CLIPTextModel(config=config)
         model.to(torch_device)
         model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
+        with torch.no_grad():
+            result = model(input_ids, attention_mask=input_mask)
+            result = model(input_ids)
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
         self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
 
@@ -429,7 +431,8 @@ class CLIPModelTester:
 
     def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
         model = CLIPModel(config).to(torch_device).eval()
-        result = model(input_ids, pixel_values, attention_mask)
+        with torch.no_grad():
+            result = model(input_ids, pixel_values, attention_mask)
         self.parent.assertEqual(
             result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
         )