mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
131 lines
4.3 KiB
Python
131 lines
4.3 KiB
Python
# coding=utf-8
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# Copyright (c) HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import json
|
|
import os
|
|
from collections import Counter
|
|
from PIL import Image
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from torch.utils.data import Dataset
|
|
|
|
# Maps a requested number of image embeddings to the 2-D output size handed to
# nn.AdaptiveAvgPool2d. For every entry the product of the pair equals the key,
# so flattening the pooled grid yields exactly that many embeddings.
POOLING_BREAKDOWN = {
    1: (1, 1),
    2: (2, 1),
    3: (3, 1),
    4: (2, 2),
    5: (5, 1),
    6: (3, 2),
    7: (7, 1),
    8: (4, 2),
    9: (3, 3),
}
|
|
|
|
|
|
class ImageEncoder(nn.Module):
    """Encode images into a fixed number of feature vectors with a ResNet-152 trunk.

    The last two children of the torchvision ResNet-152 are dropped, so the
    trunk emits a spatial feature map that is then adaptively average-pooled
    down to ``args.num_image_embeds`` cells.
    """

    def __init__(self, args):
        """Build the truncated ResNet-152 and the adaptive pooling layer.

        Args:
            args: Object exposing ``num_image_embeds``; the value must be a key
                of ``POOLING_BREAKDOWN``, otherwise a ``KeyError`` is raised.
        """
        super().__init__()
        resnet = torchvision.models.resnet152(pretrained=True)
        # Keep every child module except the final two.
        trunk_layers = list(resnet.children())[:-2]
        self.model = nn.Sequential(*trunk_layers)
        pooled_grid = POOLING_BREAKDOWN[args.num_image_embeds]
        self.pool = nn.AdaptiveAvgPool2d(pooled_grid)

    def forward(self, x):
        """Return one feature vector per pooled cell for each image in ``x``.

        Shape flow for standard inputs:
        Bx3x224x224 -> Bx2048x7x7 -> Bx2048xN -> BxNx2048
        """
        feature_map = self.model(x)
        pooled = self.pool(feature_map)
        # Merge the two spatial axes into one, then put it ahead of the channels.
        per_cell = pooled.flatten(start_dim=2)
        return per_cell.transpose(1, 2).contiguous()  # BxNx2048
|
|
|
|
|
|
|
|
class JsonlDataset(Dataset):
    """Dataset of text/image/label records stored in a JSON Lines file.

    Every record is read as a dict and must provide the keys ``"text"``,
    ``"label"`` (an iterable of label names) and ``"img"`` (an image path
    relative to the directory that contains the JSON Lines file).
    """

    def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
        """Load all records into memory and remember the processing helpers.

        Args:
            data_path: Path to the JSON Lines file (one JSON object per line).
            tokenizer: Object whose ``encode(text, add_special_tokens=True)``
                returns a sequence of token ids.
            transforms: Callable applied to each RGB PIL image. The parameter
                name shadows the module-level ``transforms`` import inside
                this method only.
            labels: List of all label names; a label's position is its index
                in the multi-hot target vector.
            max_seq_length: Maximum number of text tokens kept per example
                (the first and last encoded tokens are not counted).
        """
        # Use a context manager so the file handle is closed deterministically,
        # read as UTF-8 regardless of the locale, and ignore blank lines (for
        # example a trailing newline at the end of the file).
        with open(data_path, encoding="utf-8") as jsonl_file:
            self.data = [json.loads(line) for line in jsonl_file if line.strip()]
        self.data_dir = os.path.dirname(data_path)
        self.tokenizer = tokenizer
        self.labels = labels
        self.n_classes = len(labels)
        self.max_seq_length = max_seq_length

        self.transforms = transforms

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, index):
        """Build the tensors for the record at ``index``.

        Returns:
            A dict with keys ``"image_start_token"`` and ``"image_end_token"``
            (the first and last ids produced by the tokenizer),
            ``"sentence"`` (the remaining ids, truncated to
            ``max_seq_length``), ``"image"`` (output of ``self.transforms``)
            and ``"label"`` (float multi-hot vector of length ``n_classes``).

        Raises:
            ValueError: If a record label is not present in ``self.labels``.
        """
        row = self.data[index]

        token_ids = torch.LongTensor(self.tokenizer.encode(row["text"], add_special_tokens=True))
        # The first/last ids are split off and returned separately; only the
        # ids in between are subject to truncation.
        start_token, sentence, end_token = token_ids[0], token_ids[1:-1], token_ids[-1]
        sentence = sentence[: self.max_seq_length]

        label = torch.zeros(self.n_classes)
        label[[self.labels.index(tgt) for tgt in row["label"]]] = 1

        # convert() returns a fully loaded copy, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(os.path.join(self.data_dir, row["img"])) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transforms(image)

        return {
            "image_start_token": start_token,
            "image_end_token": end_token,
            "sentence": sentence,
            "image": image,
            "label": label,
        }

    def get_label_frequencies(self):
        """Return a ``Counter`` mapping each label name to its number of occurrences."""
        label_freqs = Counter()
        for row in self.data:
            label_freqs.update(row["label"])
        return label_freqs
|
|
|
|
|
|
def collate_fn(batch):
    """Merge a list of dataset samples into padded batch tensors.

    Args:
        batch: Non-empty list of dicts with the keys ``"sentence"``,
            ``"image"``, ``"label"``, ``"image_start_token"`` and
            ``"image_end_token"``, each holding a tensor.

    Returns:
        Tuple ``(text_tensor, mask_tensor, img_tensor, img_start_token,
        img_end_token, tgt_tensor)``. ``text_tensor`` and ``mask_tensor`` are
        long tensors of shape ``(len(batch), longest sentence)``; text is
        right-padded with zeros and the mask is 1 on real tokens.
    """
    sentences = [sample["sentence"] for sample in batch]
    lengths = [len(tokens) for tokens in sentences]
    batch_size = len(batch)
    longest = max(lengths)

    text_tensor = torch.zeros(batch_size, longest, dtype=torch.long)
    mask_tensor = torch.zeros(batch_size, longest, dtype=torch.long)

    for row_idx, tokens in enumerate(sentences):
        n_tokens = lengths[row_idx]
        text_tensor[row_idx, :n_tokens] = tokens
        mask_tensor[row_idx, :n_tokens] = 1

    def stack_field(key):
        # Stack one per-sample tensor field along a new leading batch axis.
        return torch.stack([sample[key] for sample in batch])

    img_tensor = stack_field("image")
    tgt_tensor = stack_field("label")
    img_start_token = stack_field("image_start_token")
    img_end_token = stack_field("image_end_token")

    return text_tensor, mask_tensor, img_tensor, img_start_token, img_end_token, tgt_tensor
|
|
|
|
|
|
def get_mmimdb_labels():
    """Return a new list of the 23 genre label names, in fixed order."""
    genres = (
        'Crime',
        'Drama',
        'Thriller',
        'Action',
        'Comedy',
        'Romance',
        'Documentary',
        'Short',
        'Mystery',
        'History',
        'Family',
        'Adventure',
        'Fantasy',
        'Sci-Fi',
        'Western',
        'Horror',
        'Sport',
        'War',
        'Music',
        'Musical',
        'Animation',
        'Biography',
        'Film-Noir',
    )
    # Hand back a fresh list on every call so callers may mutate the result.
    return list(genres)
|
|
|
|
|
|
def get_image_transforms():
    """Return the image preprocessing pipeline as a ``transforms.Compose``.

    The steps are, in order: resize to 256, center-crop to 224, convert to a
    tensor, then normalize each channel with the hard-coded mean/std below.
    """
    channel_mean = [0.46777044, 0.44531429, 0.40661017]
    channel_std = [0.12221994, 0.12145835, 0.14380469]
    pipeline = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)
|