Move DataCollatorForMultipleChoice from the docs to the package (#34763)

* Add implementation for DataCollatorForMultipleChoice based on docs.

* Add DataCollatorForMultipleChoice to import structure.

* Remove custom DataCollatorForMultipleChoice implementations from example scripts.

* Remove custom implementations of DataCollatorForMultipleChoice from docs in English, Spanish, Japanese and Korean.

* Refactor torch version of DataCollatorForMultipleChoice to be more easily understandable.

* Apply suggested changes and run make fixup.

* fix copies, style and fixup

* add missing documentation

* nits

* fix docstring

* style

* nits

* isort

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Thomas Bauwens 2025-02-13 12:01:28 +01:00 committed by GitHub
parent 35c155052d
commit 8f137b2427
25 changed files with 361 additions and 670 deletions

@@ -71,3 +71,6 @@ Examples of use can be found in the [example scripts](../examples) or [example notebooks](../notebooks).
[[autodoc]] data.data_collator.DataCollatorWithFlattening
# DataCollatorForMultipleChoice
[[autodoc]] data.data_collator.DataCollatorForMultipleChoice
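With this autodoc entry, `DataCollatorForMultipleChoice` becomes part of the documented public API and can be imported from the top level of the package. A minimal sketch of the new import path (the checkpoint name is illustrative, not part of this diff):

```py
from transformers import AutoTokenizer, DataCollatorForMultipleChoice

# Hypothetical checkpoint; any tokenizer with a pad token works.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
```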

@@ -109,99 +109,14 @@ The preprocessing function you want to create needs to:
To apply the preprocessing function over the entire dataset, use 🤗 Datasets [`~datasets.Dataset.map`] method. You can speed up the `map` function by setting `batched=True` to process multiple elements of the dataset at once:
```py
tokenized_swag = swag.map(preprocess_function, batched=True)
>>> tokenized_swag = swag.map(preprocess_function, batched=True)
```
🤗 Transformers doesn't have a data collator for multiple choice, so you'll need to adapt the [`DataCollatorWithPadding`] to create a batch of examples. It's more efficient to *dynamically pad* the sentences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length.
`DataCollatorForMultipleChoice` flattens all the model inputs, applies padding, and then unflattens the results:
<frameworkcontent>
<pt>
To create a batch of examples, it's more efficient to *dynamically pad* the sentences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length. [`DataCollatorForMultipleChoice`] flattens all the model inputs, applies padding, and then unflattens the results.
```py
>>> from dataclasses import dataclass
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
>>> from typing import Optional, Union
>>> import torch
>>> @dataclass
... class DataCollatorForMultipleChoice:
... """
... Data collator that will dynamically pad the inputs for multiple choice received.
... """
... tokenizer: PreTrainedTokenizerBase
... padding: Union[bool, str, PaddingStrategy] = True
... max_length: Optional[int] = None
... pad_to_multiple_of: Optional[int] = None
... def __call__(self, features):
... label_name = "label" if "label" in features[0].keys() else "labels"
... labels = [feature.pop(label_name) for feature in features]
... batch_size = len(features)
... num_choices = len(features[0]["input_ids"])
... flattened_features = [
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
... ]
... flattened_features = sum(flattened_features, [])
... batch = self.tokenizer.pad(
... flattened_features,
... padding=self.padding,
... max_length=self.max_length,
... pad_to_multiple_of=self.pad_to_multiple_of,
... return_tensors="pt",
... )
... batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
... batch["labels"] = torch.tensor(labels, dtype=torch.int64)
... return batch
>>> from transformers import DataCollatorForMultipleChoice
>>> collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
```
</pt>
<tf>
```py
>>> from dataclasses import dataclass
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
>>> from typing import Optional, Union
>>> import tensorflow as tf
>>> @dataclass
... class DataCollatorForMultipleChoice:
... """
... Data collator that will dynamically pad the inputs for multiple choice received.
... """
... tokenizer: PreTrainedTokenizerBase
... padding: Union[bool, str, PaddingStrategy] = True
... max_length: Optional[int] = None
... pad_to_multiple_of: Optional[int] = None
... def __call__(self, features):
... label_name = "label" if "label" in features[0].keys() else "labels"
... labels = [feature.pop(label_name) for feature in features]
... batch_size = len(features)
... num_choices = len(features[0]["input_ids"])
... flattened_features = [
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
... ]
... flattened_features = sum(flattened_features, [])
... batch = self.tokenizer.pad(
... flattened_features,
... padding=self.padding,
... max_length=self.max_length,
... pad_to_multiple_of=self.pad_to_multiple_of,
... return_tensors="tf",
... )
... batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
... batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
... return batch
```
</tf>
</frameworkcontent>
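For intuition about the shapes involved, here is a hedged sketch of calling the `collator` from the snippet above on a toy batch; the feature values are invented for illustration:

```py
>>> features = [
...     {"input_ids": [[101, 7], [101, 8, 9]], "attention_mask": [[1, 1], [1, 1, 1]], "label": 0},
...     {"input_ids": [[101, 10, 11], [101, 12]], "attention_mask": [[1, 1, 1], [1, 1]], "label": 1},
... ]
>>> batch = collator(features)
>>> batch["input_ids"].shape  # batch_size x num_choices x padded length
torch.Size([2, 2, 3])
>>> batch["labels"]
tensor([0, 1])
```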
## Evaluate
@@ -271,7 +186,7 @@ At this point, only three steps remain:
... train_dataset=tokenized_swag["train"],
... eval_dataset=tokenized_swag["validation"],
... processing_class=tokenizer,
... data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
... data_collator=collator,
... compute_metrics=compute_metrics,
... )

@@ -91,99 +91,14 @@ Usa la función [`~datasets.Dataset.map`] de 🤗 Datasets para aplicarle la fun
tokenized_swag = swag.map(preprocess_function, batched=True)
```
🤗 Transformers no tiene un collator de datos para la tarea de selección múltiple, así que tendrías que crear uno. Puedes adaptar el [`DataCollatorWithPadding`] para crear un lote de ejemplos para selección múltiple. Este también
le *añadirá relleno de manera dinámica* a tu texto y a las etiquetas para que tengan la longitud del elemento más largo en su lote, de forma que tengan una longitud uniforme. Aunque es posible rellenar el texto en la función `tokenizer` haciendo
Para crear un lote de ejemplos para selección múltiple, este también le *añadirá relleno de manera dinámica* a tu texto y a las etiquetas para que tengan la longitud del elemento más largo en su lote, de forma que tengan una longitud uniforme. Aunque es posible rellenar el texto en la función `tokenizer` haciendo
`padding=True`, el rellenado dinámico es más eficiente.
El `DataCollatorForMultipleChoice` aplanará todas las entradas del modelo, les aplicará relleno y luego des-aplanará los resultados:
<frameworkcontent>
<pt>
El [`DataCollatorForMultipleChoice`] aplanará todas las entradas del modelo, les aplicará relleno y luego des-aplanará los resultados.
```py
>>> from dataclasses import dataclass
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
>>> from typing import Optional, Union
>>> import torch
>>> @dataclass
... class DataCollatorForMultipleChoice:
... """
... Collator de datos que le añadirá relleno de forma automática a las entradas recibidas para
... una tarea de selección múltiple.
... """
... tokenizer: PreTrainedTokenizerBase
... padding: Union[bool, str, PaddingStrategy] = True
... max_length: Optional[int] = None
... pad_to_multiple_of: Optional[int] = None
... def __call__(self, features):
... label_name = "label" if "label" in features[0].keys() else "labels"
... labels = [feature.pop(label_name) for feature in features]
... batch_size = len(features)
... num_choices = len(features[0]["input_ids"])
... flattened_features = [
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
... ]
... flattened_features = sum(flattened_features, [])
... batch = self.tokenizer.pad(
... flattened_features,
... padding=self.padding,
... max_length=self.max_length,
... pad_to_multiple_of=self.pad_to_multiple_of,
... return_tensors="pt",
... )
... batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
... batch["labels"] = torch.tensor(labels, dtype=torch.int64)
... return batch
>>> from transformers import DataCollatorForMultipleChoice
>>> collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
```
</pt>
<tf>
```py
>>> from dataclasses import dataclass
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
>>> from typing import Optional, Union
>>> import tensorflow as tf
>>> @dataclass
... class DataCollatorForMultipleChoice:
... """
... Data collator that will dynamically pad the inputs for multiple choice received.
... """
... tokenizer: PreTrainedTokenizerBase
... padding: Union[bool, str, PaddingStrategy] = True
... max_length: Optional[int] = None
... pad_to_multiple_of: Optional[int] = None
... def __call__(self, features):
... label_name = "label" if "label" in features[0].keys() else "labels"
... labels = [feature.pop(label_name) for feature in features]
... batch_size = len(features)
... num_choices = len(features[0]["input_ids"])
... flattened_features = [
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
... ]
... flattened_features = sum(flattened_features, [])
... batch = self.tokenizer.pad(
... flattened_features,
... padding=self.padding,
... max_length=self.max_length,
... pad_to_multiple_of=self.pad_to_multiple_of,
... return_tensors="tf",
... )
... batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
... batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
... return batch
```
</tf>
</frameworkcontent>
## Entrenamiento
@@ -226,7 +141,7 @@ En este punto, solo quedan tres pasos:
... train_dataset=tokenized_swag["train"],
... eval_dataset=tokenized_swag["validation"],
... processing_class=tokenizer,
... data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
... data_collator=collator,
... )
>>> trainer.train()

@@ -113,96 +113,11 @@ pip install transformers datasets evaluate
tokenized_swag = swag.map(preprocess_function, batched=True)
```
🤗 Transformers には多肢選択用のデータ照合器がないため、[`DataCollatorWithPadding`] を調整してサンプルのバッチを作成する必要があります。データセット全体を最大長までパディングするのではなく、照合中にバッチ内の最長の長さまで文を *動的にパディング* する方が効率的です。
`DataCollatorForMultipleChoice` は、すべてのモデル入力を平坦化し、パディングを適用して、結果を非平坦化します。
<frameworkcontent>
<pt>
[`DataCollatorForMultipleChoice`] は、すべてのモデル入力を平坦化し、パディングを適用して、結果を非平坦化します。
```py
>>> from dataclasses import dataclass
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
>>> from typing import Optional, Union
>>> import torch
>>> @dataclass
... class DataCollatorForMultipleChoice:
... """
... Data collator that will dynamically pad the inputs for multiple choice received.
... """
... tokenizer: PreTrainedTokenizerBase
... padding: Union[bool, str, PaddingStrategy] = True
... max_length: Optional[int] = None
... pad_to_multiple_of: Optional[int] = None
... def __call__(self, features):
... label_name = "label" if "label" in features[0].keys() else "labels"
... labels = [feature.pop(label_name) for feature in features]
... batch_size = len(features)
... num_choices = len(features[0]["input_ids"])
... flattened_features = [
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
... ]
... flattened_features = sum(flattened_features, [])
... batch = self.tokenizer.pad(
... flattened_features,
... padding=self.padding,
... max_length=self.max_length,
... pad_to_multiple_of=self.pad_to_multiple_of,
... return_tensors="pt",
... )
... batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
... batch["labels"] = torch.tensor(labels, dtype=torch.int64)
... return batch
>>> from transformers import DataCollatorForMultipleChoice
>>> collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
```
</pt>
<tf>
```py
>>> from dataclasses import dataclass
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
>>> from typing import Optional, Union
>>> import tensorflow as tf
>>> @dataclass
... class DataCollatorForMultipleChoice:
... """
... Data collator that will dynamically pad the inputs for multiple choice received.
... """
... tokenizer: PreTrainedTokenizerBase
... padding: Union[bool, str, PaddingStrategy] = True
... max_length: Optional[int] = None
... pad_to_multiple_of: Optional[int] = None
... def __call__(self, features):
... label_name = "label" if "label" in features[0].keys() else "labels"
... labels = [feature.pop(label_name) for feature in features]
... batch_size = len(features)
... num_choices = len(features[0]["input_ids"])
... flattened_features = [
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
... ]
... flattened_features = sum(flattened_features, [])
... batch = self.tokenizer.pad(
... flattened_features,
... padding=self.padding,
... max_length=self.max_length,
... pad_to_multiple_of=self.pad_to_multiple_of,
... return_tensors="tf",
... )
... batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
... batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
... return batch
```
</tf>
</frameworkcontent>
## Evaluate
@@ -272,7 +187,7 @@ tokenized_swag = swag.map(preprocess_function, batched=True)
... train_dataset=tokenized_swag["train"],
... eval_dataset=tokenized_swag["validation"],
... processing_class=tokenizer,
... data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
... data_collator=collator,
... compute_metrics=compute_metrics,
... )

@@ -112,96 +112,11 @@ pip install transformers datasets evaluate
tokenized_swag = swag.map(preprocess_function, batched=True)
```
🤗 Transformers에는 객관식용 데이터 콜레이터가 없으므로 예제 배치를 만들려면 [`DataCollatorWithPadding`]을 조정해야 합니다. 데이터 정렬 중에 전체 데이터 집합을 최대 길이로 패딩하는 대신 배치 중 가장 긴 길이로 문장을 *동적 패딩*하는 것이 더 효율적입니다.
`DataCollatorForMultipleChoice`는 모든 모델 입력을 평탄화하고 패딩을 적용하며 그 결과를 결과를 다차원화합니다:
<frameworkcontent>
<pt>
[`DataCollatorForMultipleChoice`]는 모든 모델 입력을 평탄화하고 패딩을 적용하며 그 결과를 결과를 다차원화합니다:
```py
>>> from dataclasses import dataclass
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
>>> from typing import Optional, Union
>>> import torch
>>> @dataclass
... class DataCollatorForMultipleChoice:
... """
... Data collator that will dynamically pad the inputs for multiple choice received.
... """
... tokenizer: PreTrainedTokenizerBase
... padding: Union[bool, str, PaddingStrategy] = True
... max_length: Optional[int] = None
... pad_to_multiple_of: Optional[int] = None
... def __call__(self, features):
... label_name = "label" if "label" in features[0].keys() else "labels"
... labels = [feature.pop(label_name) for feature in features]
... batch_size = len(features)
... num_choices = len(features[0]["input_ids"])
... flattened_features = [
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
... ]
... flattened_features = sum(flattened_features, [])
... batch = self.tokenizer.pad(
... flattened_features,
... padding=self.padding,
... max_length=self.max_length,
... pad_to_multiple_of=self.pad_to_multiple_of,
... return_tensors="pt",
... )
... batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
... batch["labels"] = torch.tensor(labels, dtype=torch.int64)
... return batch
>>> from transformers import DataCollatorForMultipleChoice
>>> collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
```
</pt>
<tf>
```py
>>> from dataclasses import dataclass
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
>>> from typing import Optional, Union
>>> import tensorflow as tf
>>> @dataclass
... class DataCollatorForMultipleChoice:
... """
... Data collator that will dynamically pad the inputs for multiple choice received.
... """
... tokenizer: PreTrainedTokenizerBase
... padding: Union[bool, str, PaddingStrategy] = True
... max_length: Optional[int] = None
... pad_to_multiple_of: Optional[int] = None
... def __call__(self, features):
... label_name = "label" if "label" in features[0].keys() else "labels"
... labels = [feature.pop(label_name) for feature in features]
... batch_size = len(features)
... num_choices = len(features[0]["input_ids"])
... flattened_features = [
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
... ]
... flattened_features = sum(flattened_features, [])
... batch = self.tokenizer.pad(
... flattened_features,
... padding=self.padding,
... max_length=self.max_length,
... pad_to_multiple_of=self.pad_to_multiple_of,
... return_tensors="tf",
... )
... batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
... batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
... return batch
```
</tf>
</frameworkcontent>
## 평가 하기[[evaluate]]
@@ -271,7 +186,7 @@ tokenized_swag = swag.map(preprocess_function, batched=True)
... train_dataset=tokenized_swag["train"],
... eval_dataset=tokenized_swag["validation"],
... processing_class=tokenizer,
... data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
... data_collator=collator,
... compute_metrics=compute_metrics,
... )

@@ -23,11 +23,10 @@ import os
import sys
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional, Union
from typing import Optional
import datasets
import numpy as np
import torch
from datasets import load_dataset
import transformers
@@ -35,15 +34,15 @@ from transformers import (
AutoConfig,
AutoModelForMultipleChoice,
AutoTokenizer,
DataCollatorForMultipleChoice,
HfArgumentParser,
Trainer,
TrainingArguments,
default_data_collator,
set_seed,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -165,63 +164,6 @@ class DataTrainingArguments:
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
@dataclass
class DataCollatorForMultipleChoice:
"""
Data collator that will dynamically pad the inputs for multiple choice received.
Args:
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
The tokenizer used for encoding the data.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
if provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (`int`, *optional*):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature.pop(label_name) for feature in features]
batch_size = len(features)
num_choices = len(features[0]["input_ids"])
flattened_features = [
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
]
flattened_features = list(chain(*flattened_features))
batch = self.tokenizer.pad(
flattened_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
# Un-flatten
batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
# Add back labels
batch["labels"] = torch.tensor(labels, dtype=torch.int64)
return batch
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
@@ -425,7 +367,9 @@ def main():
data_collator = (
default_data_collator
if data_args.pad_to_max_length
else DataCollatorForMultipleChoice(tokenizer=tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
else DataCollatorForMultipleChoice(
tokenizer=tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None, return_tensors="pt"
)
)
# Metric
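The `pad_to_multiple_of=8 if training_args.fp16 else None` choice follows the Tensor Core guidance in the collator's docstring. A self-contained sketch of the resulting padded lengths (the helper is hypothetical, mirroring `tokenizer.pad`'s rounding rule):

```py
import math
from typing import Optional

def padded_length(longest: int, multiple: Optional[int]) -> int:
    # Pad to the batch's longest sequence, rounded up to the next
    # multiple when pad_to_multiple_of is given.
    return longest if multiple is None else math.ceil(longest / multiple) * multiple

print(padded_length(37, 8))     # 40 -- fp16 run, Tensor Core friendly
print(padded_length(37, None))  # 37 -- default run
```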

@@ -24,10 +24,8 @@ import logging
import math
import os
import random
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from typing import Optional, Union
import datasets
import evaluate
@@ -47,12 +45,12 @@ from transformers import (
AutoConfig,
AutoModelForMultipleChoice,
AutoTokenizer,
PreTrainedTokenizerBase,
DataCollatorForMultipleChoice,
SchedulerType,
default_data_collator,
get_scheduler,
)
from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -226,63 +224,6 @@ def parse_args():
return args
@dataclass
class DataCollatorForMultipleChoice:
"""
Data collator that will dynamically pad the inputs for multiple choice received.
Args:
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
The tokenizer used for encoding the data.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
if provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (`int`, *optional*):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature.pop(label_name) for feature in features]
batch_size = len(features)
num_choices = len(features[0]["input_ids"])
flattened_features = [
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
]
flattened_features = list(chain(*flattened_features))
batch = self.tokenizer.pad(
flattened_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
# Un-flatten
batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
# Add back labels
batch["labels"] = torch.tensor(labels, dtype=torch.int64)
return batch
def main():
args = parse_args()
@@ -480,7 +421,9 @@ def main():
pad_to_multiple_of = 8
else:
pad_to_multiple_of = None
data_collator = DataCollatorForMultipleChoice(tokenizer, pad_to_multiple_of=pad_to_multiple_of)
data_collator = DataCollatorForMultipleChoice(
tokenizer, pad_to_multiple_of=pad_to_multiple_of, return_tensors="pt"
)
train_dataloader = DataLoader(
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size

@@ -23,21 +23,18 @@
}
],
"source": [
"from IPython.display import clear_output, Image, display\n",
"import PIL.Image\n",
"import io\n",
"import json\n",
"import torch\n",
"\n",
"import numpy as np\n",
"import PIL.Image\n",
"from IPython.display import Image, display\n",
"from modeling_frcnn import GeneralizedRCNN\n",
"from processing_image import Preprocess\n",
"from visualizing_image import SingleImageViz\n",
"from modeling_frcnn import GeneralizedRCNN\n",
"from utils import Config\n",
"\n",
"import utils\n",
"from transformers import LxmertForQuestionAnswering, LxmertTokenizer\n",
"import wget\n",
"import pickle\n",
"import os\n",
"from utils import Config\n",
"\n",
"\n",
"# URL = \"https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/images/input.jpg\",\n",

@@ -31,19 +31,19 @@
"source": [
"# Includes\n",
"\n",
"import h5py\n",
"import os\n",
"import json\n",
"import os\n",
"from collections import OrderedDict\n",
"\n",
"from scipy import sparse\n",
"import h5py\n",
"import numpy as np\n",
"\n",
"import torch\n",
"from scipy import sparse\n",
"from torch import nn\n",
"\n",
"from transformers import *\n",
"\n",
"\n",
"os.chdir(\"../../\")"
]
},

File diff suppressed because one or more lines are too long

@@ -25,7 +25,7 @@ import sys
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import Optional, Union
from typing import Optional
import datasets
import tensorflow as tf
@@ -37,6 +37,7 @@ from transformers import (
TF2_WEIGHTS_NAME,
AutoConfig,
AutoTokenizer,
DataCollatorForMultipleChoice,
DefaultDataCollator,
HfArgumentParser,
PushToHubCallback,
@@ -45,8 +46,7 @@ from transformers import (
create_optimizer,
set_seed,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -55,69 +55,6 @@ check_min_version("4.49.0.dev0")
logger = logging.getLogger(__name__)
# region Helper classes and functions
@dataclass
class DataCollatorForMultipleChoice:
"""
Data collator that will dynamically pad the inputs for multiple choice received.
Args:
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
The tokenizer used for encoding the data.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
if provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (`int`, *optional*):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature.pop(label_name) for feature in features]
batch_size = len(features)
num_choices = len(features[0]["input_ids"])
flattened_features = [
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
]
flattened_features = list(chain(*flattened_features))
batch = self.tokenizer.pad(
flattened_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="np",
)
# Un-flatten
batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
# Add back labels
batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
return batch
# endregion
# region Arguments
@dataclass
class ModelArguments:
@@ -424,8 +361,7 @@ def main():
if data_args.pad_to_max_length:
data_collator = DefaultDataCollator(return_tensors="np")
else:
# custom class defined above, as HF has no data collator for multiple choice
data_collator = DataCollatorForMultipleChoice(tokenizer)
data_collator = DataCollatorForMultipleChoice(tokenizer, return_tensors="tf")
# endregion
with training_args.strategy.scope():

@@ -0,0 +1,113 @@
import argparse
import os
import torch
import torch.distributed as dist
# Environment variables set by torch.distributed.launch
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])
WORLD_RANK = int(os.environ["RANK"])
LOCAL_RANK = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
WORLD_SIZE = int(os.environ["OMPI_COMM_WORLD_SIZE"])
WORLD_RANK = int(os.environ["OMPI_COMM_WORLD_RANK"])
def run(backend):
tensor = torch.zeros(1)
# Need to put tensor on a GPU device for nccl backend
if backend == "nccl":
device = torch.device("cuda:{}".format(LOCAL_RANK))
tensor = tensor.to(device)
if WORLD_RANK == 0:
for rank_recv in range(1, WORLD_SIZE):
dist.send(tensor=tensor, dst=rank_recv)
print("worker_{} sent data to Rank {}\n".format(0, rank_recv))
else:
dist.recv(tensor=tensor, src=0)
print("worker_{} has received data from rank {}\n".format(WORLD_RANK, 0))
def init_processes(backend):
dist.init_process_group(backend, rank=WORLD_RANK, world_size=WORLD_SIZE)
run(backend)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility."
)
parser.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo"])
args = parser.parse_args()
init_processes(backend=args.backend)
""""
python -m torch.distributed.launch \
--nproc_per_node=2 --nnodes=2 --node_rank=0 \
test_compile.py
python3 -m torch.distributed.launch \
--nproc_per_node=2 --nnodes=2 --node_rank=1 \
--master_addr=104.171.200.62 --master_port=1234 \
main.py \
--backend=nccl --use_syn --batch_size=8192 --arch=resnet152
mpirun -np 4 \
-H 104.171.200.62:2,104.171.200.182:2 \
-x MASTER_ADDR=104.171.200.62 \
-x MASTER_PORT=1234 \
-x PATH \
-bind-to none -map-by slot \
-mca pml ob1 -mca btl ^openib \
python3 main.py
"""
""""
You need a host file with the name of hosts.
for example I have arthur@ip-26-0-162-46 and arthur@ip-26-0-162-239
________
hostfile
ip-26-0-162-46 slots=8
ip-26-0-162-239 slots=8
________
mpirun --hostfile hostfile -np 16 \
--bind-to none --map-by slot \
-x MASTER_ADDR=<master-node-ip> \
-x MASTER_PORT=29500 \
-x NCCL_DEBUG=INFO \
-x NCCL_SOCKET_IFNAME=^lo,docker0 \
-x CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python your_script.py --backend nccl
to get the master IP you need to do a few things:
hostname -I | awk '{print $1}'
Use `ping ip-26-0-162-46` to check if connected
26.0.162.46
mpirun --hostfile hostfile -np 16 \
--bind-to none --map-by slot \
-x MASTER_ADDR=26.0.162.46 \
-x MASTER_PORT=29500 \
-x NCCL_DEBUG=INFO \
-x NCCL_SOCKET_IFNAME=^lo,docker0 \
-x CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python your_script.py --backend nccl
mpirun --hostfile hostfile -np 2 -x NCCL_DEBUG=INFO python -c "import os;print(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])" -b 8 -e 128M -f 2 -g 1
to test your setup
"""

@@ -101,6 +101,7 @@ _import_structure = {
"data.data_collator": [
"DataCollator",
"DataCollatorForLanguageModeling",
"DataCollatorForMultipleChoice",
"DataCollatorForPermutationLanguageModeling",
"DataCollatorForSeq2Seq",
"DataCollatorForSOP",
@@ -5187,6 +5188,7 @@ if TYPE_CHECKING:
from .data.data_collator import (
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForMultipleChoice,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSeq2Seq,
DataCollatorForSOP,

@@ -609,7 +609,7 @@ def compile_jinja_template(template):
raise ImportError("template requires jinja2 to be installed.")
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError("template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}.")
raise ImportError(f"template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}.")
def raise_exception(message):
raise TemplateError(message)

@@ -751,7 +751,7 @@ class PretrainedConfig(PushToHubMixin):
id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
if len(id2label) != num_labels:
raise ValueError(
f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
"one of them."
)

@@ -14,6 +14,7 @@
from .data_collator import (
DataCollatorForLanguageModeling,
DataCollatorForMultipleChoice,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSeq2Seq,
DataCollatorForSOP,

@@ -532,12 +532,95 @@ def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int]
return result
def tolist(x):
if isinstance(x, list):
return x
elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import
x = x.numpy()
return x.tolist()
@dataclass
class DataCollatorForMultipleChoice(DataCollatorMixin):
"""
Data collator that dynamically pads a batch of nested examples for multiple choice, so that all choices
of all examples have the same length.
Args:
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
The tokenizer used for encoding the data.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
Select a strategy to pad the returned sequences according to the model's padding side and padding index
among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (`int`, *optional*):
Pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
return_tensors (`str`, *optional*, defaults to `"pt"`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def torch_call(self, examples: List[Dict[str, Any]]): # Refactored implementation from the docs.
import torch
# Take labels out of the examples beforehand, because they aren't nested.
label_name = "label" if "label" in examples[0].keys() else "labels"
labels = [example.pop(label_name) for example in examples]
batch_size = len(examples)
num_choices = len(examples[0]["input_ids"])
# Go from e.g. 2 examples of 2 choices [{input_ids: [[1], [2]]}, {input_ids: [[3], [4]]}]
# to 4 examples [{input_ids: [1]}, {input_ids: [2]}] + [{input_ids: [3]}, {input_ids: [4]}]
flat_examples = sum(
([{k: v[i] for k, v in example.items()} for i in range(num_choices)] for example in examples), start=[]
)
# Pad all choices of all examples as if you're padding any other batch of examples.
batch = self.tokenizer.pad(
flat_examples,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
# Reshape from B*C x L into B x C x L, and add the labels back in.
batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
batch["labels"] = torch.tensor(labels, dtype=torch.int64)
return batch
def tf_call(self, features): # Implementation taken from the docs.
import tensorflow as tf
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature.pop(label_name) for feature in features]
batch_size = len(features)
num_choices = len(features[0]["input_ids"])
flattened_features = [
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
]
flattened_features = sum(flattened_features, []) # Sometimes written as list(chain(*flattened_features))
batch = self.tokenizer.pad(
flattened_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="tf",
)
batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
return batch
@dataclass
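The only subtle step in `torch_call`/`tf_call` above is the flattening. A self-contained sketch of just that transformation, with toy values invented for illustration:

```py
examples = [
    {"input_ids": [[1, 2], [3]]},     # example 0: two choices
    {"input_ids": [[4, 5, 6], [7]]},  # example 1: two choices
]
num_choices = len(examples[0]["input_ids"])
flat = sum(
    ([{k: v[i] for k, v in ex.items()} for i in range(num_choices)] for ex in examples),
    start=[],
)
print(flat)
# [{'input_ids': [1, 2]}, {'input_ids': [3]}, {'input_ids': [4, 5, 6]}, {'input_ids': [7]}]
```

`tokenizer.pad` then treats `flat` as an ordinary batch of four sequences, and the final `view(batch_size, num_choices, -1)` (or `tf.reshape`) restores the batch x choices x length nesting.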
@@ -1268,6 +1351,14 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
return inputs, labels
def tolist(x):
if isinstance(x, list):
return x
elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import
x = x.numpy()
return x.tolist()
@dataclass
class DataCollatorForSOP(DataCollatorForLanguageModeling):
"""

@@ -35,9 +35,7 @@ def convert_tiktoken_to_fast(encoding: Any, output_dir: str):
dump_tiktoken_bpe(encoding._mergeable_ranks, save_file_absolute)
except ImportError:
raise ValueError(
"`tiktoken` is required to save a `tiktoken` file. Install it with " "`pip install tiktoken`."
)
raise ValueError("`tiktoken` is required to save a `tiktoken` file. Install it with `pip install tiktoken`.")
tokenizer = TikTokenConverter(
vocab_file=save_file_absolute, pattern=encoding._pat_str, additional_special_tokens=encoding._special_tokens

@@ -270,7 +270,7 @@ class BeitSelfAttention(nn.Module):
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"The hidden size {(config.hidden_size,)} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)

@@ -271,7 +271,7 @@ class Data2VecVisionSelfAttention(nn.Module):
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"The hidden size {(config.hidden_size,)} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)

@@ -183,7 +183,7 @@ class OlmoConfig(PretrainedConfig):
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)

@@ -166,7 +166,7 @@ class Olmo2Config(PretrainedConfig):
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)

@@ -2195,7 +2195,7 @@ def pytest_terminal_summary_main(tr, id):
f.write("slowest durations\n")
for i, rep in enumerate(dlist):
if rep.duration < durations_min:
f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
break
f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
@@ -2580,7 +2580,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
process.join(timeout=timeout)
if results["error"] is not None:
test_case.fail(f'{results["error"]}')
test_case.fail(f"{results['error']}")
def run_test_using_subprocess(func):

@@ -182,7 +182,7 @@ class DepthProModelTester:
model_name = model.__class__.__name__
self.parent.assertTrue(
diff <= 1e-03,
msg=(f"Batched and Single row outputs are not equal in {model_name} for fov. " f"Difference={diff}."),
msg=(f"Batched and Single row outputs are not equal in {model_name} for fov. Difference={diff}."),
)
def prepare_config_and_inputs_for_common(self):

@@ -687,7 +687,7 @@ class TrainerIntegrationCommon:
keys = list(state_dict.keys())
shard_files = [
shard_name.replace(f".{extension}", f"-{idx+1:05d}-of-{len(keys):05d}.{extension}")
shard_name.replace(f".{extension}", f"-{idx + 1:05d}-of-{len(keys):05d}.{extension}")
for idx in range(len(keys))
]
index = {"metadata": {}, "weight_map": {key: shard_files[i] for i, key in enumerate(keys)}}