mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Move DataCollatorForMultipleChoice
from the docs to the package (#34763)
* Add implementation for DataCollatorForMultipleChoice based on docs. * Add DataCollatorForMultipleChoice to import structure. * Remove custom DataCollatorForMultipleChoice implementations from example scripts. * Remove custom implementations of DataCollatorForMultipleChoice from docs in English, Spanish, Japanese and Korean. * Refactor torch version of DataCollatorForMultipleChoice to be more easily understandable. * Apply suggested changes and run make fixup. * fix copies, style and fixup * add missing documentation * nits * fix docstring * style * nits * isort --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
parent
35c155052d
commit
8f137b2427
@ -71,3 +71,6 @@ Examples of use can be found in the [example scripts](../examples) or [example n
|
||||
|
||||
[[autodoc]] data.data_collator.DataCollatorWithFlattening
|
||||
|
||||
# DataCollatorForMultipleChoice
|
||||
|
||||
[[autodoc]] data.data_collator.DataCollatorForMultipleChoice
|
||||
|
@ -109,99 +109,14 @@ The preprocessing function you want to create needs to:
|
||||
To apply the preprocessing function over the entire dataset, use 🤗 Datasets [`~datasets.Dataset.map`] method. You can speed up the `map` function by setting `batched=True` to process multiple elements of the dataset at once:
|
||||
|
||||
```py
|
||||
tokenized_swag = swag.map(preprocess_function, batched=True)
|
||||
>>> tokenized_swag = swag.map(preprocess_function, batched=True)
|
||||
```
|
||||
|
||||
🤗 Transformers doesn't have a data collator for multiple choice, so you'll need to adapt the [`DataCollatorWithPadding`] to create a batch of examples. It's more efficient to *dynamically pad* the sentences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length.
|
||||
|
||||
`DataCollatorForMultipleChoice` flattens all the model inputs, applies padding, and then unflattens the results:
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
To create a batch of examples, it's more efficient to *dynamically pad* the sentences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length. [`DataCollatorForMultipleChoice`] flattens all the model inputs, applies padding, and then unflattens the results.
|
||||
```py
|
||||
>>> from dataclasses import dataclass
|
||||
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
>>> from typing import Optional, Union
|
||||
>>> import torch
|
||||
|
||||
|
||||
>>> @dataclass
|
||||
... class DataCollatorForMultipleChoice:
|
||||
... """
|
||||
... Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
... """
|
||||
|
||||
... tokenizer: PreTrainedTokenizerBase
|
||||
... padding: Union[bool, str, PaddingStrategy] = True
|
||||
... max_length: Optional[int] = None
|
||||
... pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
... def __call__(self, features):
|
||||
... label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
... labels = [feature.pop(label_name) for feature in features]
|
||||
... batch_size = len(features)
|
||||
... num_choices = len(features[0]["input_ids"])
|
||||
... flattened_features = [
|
||||
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
... ]
|
||||
... flattened_features = sum(flattened_features, [])
|
||||
|
||||
... batch = self.tokenizer.pad(
|
||||
... flattened_features,
|
||||
... padding=self.padding,
|
||||
... max_length=self.max_length,
|
||||
... pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
|
||||
... batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
|
||||
... batch["labels"] = torch.tensor(labels, dtype=torch.int64)
|
||||
... return batch
|
||||
>>> from transformers import DataCollatorForMultipleChoice
|
||||
>>> collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
|
||||
```
|
||||
</pt>
|
||||
<tf>
|
||||
```py
|
||||
>>> from dataclasses import dataclass
|
||||
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
>>> from typing import Optional, Union
|
||||
>>> import tensorflow as tf
|
||||
|
||||
|
||||
>>> @dataclass
|
||||
... class DataCollatorForMultipleChoice:
|
||||
... """
|
||||
... Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
... """
|
||||
|
||||
... tokenizer: PreTrainedTokenizerBase
|
||||
... padding: Union[bool, str, PaddingStrategy] = True
|
||||
... max_length: Optional[int] = None
|
||||
... pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
... def __call__(self, features):
|
||||
... label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
... labels = [feature.pop(label_name) for feature in features]
|
||||
... batch_size = len(features)
|
||||
... num_choices = len(features[0]["input_ids"])
|
||||
... flattened_features = [
|
||||
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
... ]
|
||||
... flattened_features = sum(flattened_features, [])
|
||||
|
||||
... batch = self.tokenizer.pad(
|
||||
... flattened_features,
|
||||
... padding=self.padding,
|
||||
... max_length=self.max_length,
|
||||
... pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
... return_tensors="tf",
|
||||
... )
|
||||
|
||||
... batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
|
||||
... batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
|
||||
... return batch
|
||||
```
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
## Evaluate
|
||||
|
||||
@ -271,7 +186,7 @@ At this point, only three steps remain:
|
||||
... train_dataset=tokenized_swag["train"],
|
||||
... eval_dataset=tokenized_swag["validation"],
|
||||
... processing_class=tokenizer,
|
||||
... data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
|
||||
... data_collator=collator,
|
||||
... compute_metrics=compute_metrics,
|
||||
... )
|
||||
|
||||
|
@ -91,99 +91,14 @@ Usa la función [`~datasets.Dataset.map`] de 🤗 Datasets para aplicarle la fun
|
||||
tokenized_swag = swag.map(preprocess_function, batched=True)
|
||||
```
|
||||
|
||||
🤗 Transformers no tiene un collator de datos para la tarea de selección múltiple, así que tendrías que crear uno. Puedes adaptar el [`DataCollatorWithPadding`] para crear un lote de ejemplos para selección múltiple. Este también
|
||||
le *añadirá relleno de manera dinámica* a tu texto y a las etiquetas para que tengan la longitud del elemento más largo en su lote, de forma que tengan una longitud uniforme. Aunque es posible rellenar el texto en la función `tokenizer` haciendo
|
||||
Para crear un lote de ejemplos para selección múltiple, este también le *añadirá relleno de manera dinámica* a tu texto y a las etiquetas para que tengan la longitud del elemento más largo en su lote, de forma que tengan una longitud uniforme. Aunque es posible rellenar el texto en la función `tokenizer` haciendo
|
||||
`padding=True`, el rellenado dinámico es más eficiente.
|
||||
|
||||
El `DataCollatorForMultipleChoice` aplanará todas las entradas del modelo, les aplicará relleno y luego des-aplanará los resultados:
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
El [`DataCollatorForMultipleChoice`] aplanará todas las entradas del modelo, les aplicará relleno y luego des-aplanará los resultados.
|
||||
```py
|
||||
>>> from dataclasses import dataclass
|
||||
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
>>> from typing import Optional, Union
|
||||
>>> import torch
|
||||
|
||||
|
||||
>>> @dataclass
|
||||
... class DataCollatorForMultipleChoice:
|
||||
... """
|
||||
... Collator de datos que le añadirá relleno de forma automática a las entradas recibidas para
|
||||
... una tarea de selección múltiple.
|
||||
... """
|
||||
|
||||
... tokenizer: PreTrainedTokenizerBase
|
||||
... padding: Union[bool, str, PaddingStrategy] = True
|
||||
... max_length: Optional[int] = None
|
||||
... pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
... def __call__(self, features):
|
||||
... label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
... labels = [feature.pop(label_name) for feature in features]
|
||||
... batch_size = len(features)
|
||||
... num_choices = len(features[0]["input_ids"])
|
||||
... flattened_features = [
|
||||
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
... ]
|
||||
... flattened_features = sum(flattened_features, [])
|
||||
|
||||
... batch = self.tokenizer.pad(
|
||||
... flattened_features,
|
||||
... padding=self.padding,
|
||||
... max_length=self.max_length,
|
||||
... pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
|
||||
... batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
|
||||
... batch["labels"] = torch.tensor(labels, dtype=torch.int64)
|
||||
... return batch
|
||||
>>> from transformers import DataCollatorForMultipleChoice
|
||||
>>> collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
|
||||
```
|
||||
</pt>
|
||||
<tf>
|
||||
```py
|
||||
>>> from dataclasses import dataclass
|
||||
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
>>> from typing import Optional, Union
|
||||
>>> import tensorflow as tf
|
||||
|
||||
|
||||
>>> @dataclass
|
||||
... class DataCollatorForMultipleChoice:
|
||||
... """
|
||||
... Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
... """
|
||||
|
||||
... tokenizer: PreTrainedTokenizerBase
|
||||
... padding: Union[bool, str, PaddingStrategy] = True
|
||||
... max_length: Optional[int] = None
|
||||
... pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
... def __call__(self, features):
|
||||
... label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
... labels = [feature.pop(label_name) for feature in features]
|
||||
... batch_size = len(features)
|
||||
... num_choices = len(features[0]["input_ids"])
|
||||
... flattened_features = [
|
||||
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
... ]
|
||||
... flattened_features = sum(flattened_features, [])
|
||||
|
||||
... batch = self.tokenizer.pad(
|
||||
... flattened_features,
|
||||
... padding=self.padding,
|
||||
... max_length=self.max_length,
|
||||
... pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
... return_tensors="tf",
|
||||
... )
|
||||
|
||||
... batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
|
||||
... batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
|
||||
... return batch
|
||||
```
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
## Entrenamiento
|
||||
|
||||
@ -226,7 +141,7 @@ En este punto, solo quedan tres pasos:
|
||||
... train_dataset=tokenized_swag["train"],
|
||||
... eval_dataset=tokenized_swag["validation"],
|
||||
... processing_class=tokenizer,
|
||||
... data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
|
||||
... data_collator=collator,
|
||||
... )
|
||||
|
||||
>>> trainer.train()
|
||||
|
@ -113,96 +113,11 @@ pip install transformers datasets evaluate
|
||||
tokenized_swag = swag.map(preprocess_function, batched=True)
|
||||
```
|
||||
|
||||
🤗 Transformers には多肢選択用のデータ照合器がないため、[`DataCollatorWithPadding`] を調整してサンプルのバッチを作成する必要があります。データセット全体を最大長までパディングするのではなく、照合中にバッチ内の最長の長さまで文を *動的にパディング* する方が効率的です。
|
||||
|
||||
`DataCollatorForMultipleChoice` は、すべてのモデル入力を平坦化し、パディングを適用して、結果を非平坦化します。
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
[`DataCollatorForMultipleChoice`] は、すべてのモデル入力を平坦化し、パディングを適用して、結果を非平坦化します。
|
||||
```py
|
||||
>>> from dataclasses import dataclass
|
||||
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
>>> from typing import Optional, Union
|
||||
>>> import torch
|
||||
|
||||
|
||||
>>> @dataclass
|
||||
... class DataCollatorForMultipleChoice:
|
||||
... """
|
||||
... Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
... """
|
||||
|
||||
... tokenizer: PreTrainedTokenizerBase
|
||||
... padding: Union[bool, str, PaddingStrategy] = True
|
||||
... max_length: Optional[int] = None
|
||||
... pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
... def __call__(self, features):
|
||||
... label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
... labels = [feature.pop(label_name) for feature in features]
|
||||
... batch_size = len(features)
|
||||
... num_choices = len(features[0]["input_ids"])
|
||||
... flattened_features = [
|
||||
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
... ]
|
||||
... flattened_features = sum(flattened_features, [])
|
||||
|
||||
... batch = self.tokenizer.pad(
|
||||
... flattened_features,
|
||||
... padding=self.padding,
|
||||
... max_length=self.max_length,
|
||||
... pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
|
||||
... batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
|
||||
... batch["labels"] = torch.tensor(labels, dtype=torch.int64)
|
||||
... return batch
|
||||
>>> from transformers import DataCollatorForMultipleChoice
|
||||
>>> collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
|
||||
```
|
||||
</pt>
|
||||
<tf>
|
||||
```py
|
||||
>>> from dataclasses import dataclass
|
||||
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
>>> from typing import Optional, Union
|
||||
>>> import tensorflow as tf
|
||||
|
||||
|
||||
>>> @dataclass
|
||||
... class DataCollatorForMultipleChoice:
|
||||
... """
|
||||
... Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
... """
|
||||
|
||||
... tokenizer: PreTrainedTokenizerBase
|
||||
... padding: Union[bool, str, PaddingStrategy] = True
|
||||
... max_length: Optional[int] = None
|
||||
... pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
... def __call__(self, features):
|
||||
... label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
... labels = [feature.pop(label_name) for feature in features]
|
||||
... batch_size = len(features)
|
||||
... num_choices = len(features[0]["input_ids"])
|
||||
... flattened_features = [
|
||||
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
... ]
|
||||
... flattened_features = sum(flattened_features, [])
|
||||
|
||||
... batch = self.tokenizer.pad(
|
||||
... flattened_features,
|
||||
... padding=self.padding,
|
||||
... max_length=self.max_length,
|
||||
... pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
... return_tensors="tf",
|
||||
... )
|
||||
|
||||
... batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
|
||||
... batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
|
||||
... return batch
|
||||
```
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
## Evaluate
|
||||
|
||||
@ -272,7 +187,7 @@ tokenized_swag = swag.map(preprocess_function, batched=True)
|
||||
... train_dataset=tokenized_swag["train"],
|
||||
... eval_dataset=tokenized_swag["validation"],
|
||||
... processing_class=tokenizer,
|
||||
... data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
|
||||
... data_collator=collator,
|
||||
... compute_metrics=compute_metrics,
|
||||
... )
|
||||
|
||||
|
@ -112,96 +112,11 @@ pip install transformers datasets evaluate
|
||||
tokenized_swag = swag.map(preprocess_function, batched=True)
|
||||
```
|
||||
|
||||
🤗 Transformers에는 객관식용 데이터 콜레이터가 없으므로 예제 배치를 만들려면 [`DataCollatorWithPadding`]을 조정해야 합니다. 데이터 정렬 중에 전체 데이터 집합을 최대 길이로 패딩하는 대신 배치 중 가장 긴 길이로 문장을 *동적 패딩*하는 것이 더 효율적입니다.
|
||||
|
||||
`DataCollatorForMultipleChoice`는 모든 모델 입력을 평탄화하고 패딩을 적용하며 그 결과를 결과를 다차원화합니다:
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
[`DataCollatorForMultipleChoice`]는 모든 모델 입력을 평탄화하고 패딩을 적용하며 그 결과를 결과를 다차원화합니다:
|
||||
```py
|
||||
>>> from dataclasses import dataclass
|
||||
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
>>> from typing import Optional, Union
|
||||
>>> import torch
|
||||
|
||||
|
||||
>>> @dataclass
|
||||
... class DataCollatorForMultipleChoice:
|
||||
... """
|
||||
... Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
... """
|
||||
|
||||
... tokenizer: PreTrainedTokenizerBase
|
||||
... padding: Union[bool, str, PaddingStrategy] = True
|
||||
... max_length: Optional[int] = None
|
||||
... pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
... def __call__(self, features):
|
||||
... label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
... labels = [feature.pop(label_name) for feature in features]
|
||||
... batch_size = len(features)
|
||||
... num_choices = len(features[0]["input_ids"])
|
||||
... flattened_features = [
|
||||
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
... ]
|
||||
... flattened_features = sum(flattened_features, [])
|
||||
|
||||
... batch = self.tokenizer.pad(
|
||||
... flattened_features,
|
||||
... padding=self.padding,
|
||||
... max_length=self.max_length,
|
||||
... pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
|
||||
... batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
|
||||
... batch["labels"] = torch.tensor(labels, dtype=torch.int64)
|
||||
... return batch
|
||||
>>> from transformers import DataCollatorForMultipleChoice
|
||||
>>> collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
|
||||
```
|
||||
</pt>
|
||||
<tf>
|
||||
```py
|
||||
>>> from dataclasses import dataclass
|
||||
>>> from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
|
||||
>>> from typing import Optional, Union
|
||||
>>> import tensorflow as tf
|
||||
|
||||
|
||||
>>> @dataclass
|
||||
... class DataCollatorForMultipleChoice:
|
||||
... """
|
||||
... Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
... """
|
||||
|
||||
... tokenizer: PreTrainedTokenizerBase
|
||||
... padding: Union[bool, str, PaddingStrategy] = True
|
||||
... max_length: Optional[int] = None
|
||||
... pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
... def __call__(self, features):
|
||||
... label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
... labels = [feature.pop(label_name) for feature in features]
|
||||
... batch_size = len(features)
|
||||
... num_choices = len(features[0]["input_ids"])
|
||||
... flattened_features = [
|
||||
... [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
... ]
|
||||
... flattened_features = sum(flattened_features, [])
|
||||
|
||||
... batch = self.tokenizer.pad(
|
||||
... flattened_features,
|
||||
... padding=self.padding,
|
||||
... max_length=self.max_length,
|
||||
... pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
... return_tensors="tf",
|
||||
... )
|
||||
|
||||
... batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
|
||||
... batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
|
||||
... return batch
|
||||
```
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
## 평가 하기[[evaluate]]
|
||||
|
||||
@ -271,7 +186,7 @@ tokenized_swag = swag.map(preprocess_function, batched=True)
|
||||
... train_dataset=tokenized_swag["train"],
|
||||
... eval_dataset=tokenized_swag["validation"],
|
||||
... processing_class=tokenizer,
|
||||
... data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
|
||||
... data_collator=collator,
|
||||
... compute_metrics=compute_metrics,
|
||||
... )
|
||||
|
||||
|
@ -23,11 +23,10 @@ import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
import transformers
|
||||
@ -35,15 +34,15 @@ from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoTokenizer,
|
||||
DataCollatorForMultipleChoice,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
|
||||
from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
@ -165,63 +164,6 @@ class DataTrainingArguments:
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForMultipleChoice:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
Args:
|
||||
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
||||
The tokenizer used for encoding the data.
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||
among:
|
||||
|
||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
|
||||
if provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
||||
lengths).
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (`int`, *optional*):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||
7.5 (Volta).
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features):
|
||||
label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
labels = [feature.pop(label_name) for feature in features]
|
||||
batch_size = len(features)
|
||||
num_choices = len(features[0]["input_ids"])
|
||||
flattened_features = [
|
||||
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
]
|
||||
flattened_features = list(chain(*flattened_features))
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flattened_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Un-flatten
|
||||
batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
|
||||
# Add back labels
|
||||
batch["labels"] = torch.tensor(labels, dtype=torch.int64)
|
||||
return batch
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
@ -425,7 +367,9 @@ def main():
|
||||
data_collator = (
|
||||
default_data_collator
|
||||
if data_args.pad_to_max_length
|
||||
else DataCollatorForMultipleChoice(tokenizer=tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
|
||||
else DataCollatorForMultipleChoice(
|
||||
tokenizer=tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None, return_tensors="pt"
|
||||
)
|
||||
)
|
||||
|
||||
# Metric
|
||||
|
@ -24,10 +24,8 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
@ -47,12 +45,12 @@ from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
DataCollatorForMultipleChoice,
|
||||
SchedulerType,
|
||||
default_data_collator,
|
||||
get_scheduler,
|
||||
)
|
||||
from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
|
||||
from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
@ -226,63 +224,6 @@ def parse_args():
|
||||
return args
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForMultipleChoice:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
Args:
|
||||
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
||||
The tokenizer used for encoding the data.
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||
among:
|
||||
|
||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
|
||||
if provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
||||
lengths).
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (`int`, *optional*):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||
7.5 (Volta).
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features):
|
||||
label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
labels = [feature.pop(label_name) for feature in features]
|
||||
batch_size = len(features)
|
||||
num_choices = len(features[0]["input_ids"])
|
||||
flattened_features = [
|
||||
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
]
|
||||
flattened_features = list(chain(*flattened_features))
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flattened_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Un-flatten
|
||||
batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
|
||||
# Add back labels
|
||||
batch["labels"] = torch.tensor(labels, dtype=torch.int64)
|
||||
return batch
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
@ -480,7 +421,9 @@ def main():
|
||||
pad_to_multiple_of = 8
|
||||
else:
|
||||
pad_to_multiple_of = None
|
||||
data_collator = DataCollatorForMultipleChoice(tokenizer, pad_to_multiple_of=pad_to_multiple_of)
|
||||
data_collator = DataCollatorForMultipleChoice(
|
||||
tokenizer, pad_to_multiple_of=pad_to_multiple_of, return_tensors="pt"
|
||||
)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
|
||||
|
@ -23,21 +23,18 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from IPython.display import clear_output, Image, display\n",
|
||||
"import PIL.Image\n",
|
||||
"import io\n",
|
||||
"import json\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import PIL.Image\n",
|
||||
"from IPython.display import Image, display\n",
|
||||
"from modeling_frcnn import GeneralizedRCNN\n",
|
||||
"from processing_image import Preprocess\n",
|
||||
"from visualizing_image import SingleImageViz\n",
|
||||
"from modeling_frcnn import GeneralizedRCNN\n",
|
||||
"from utils import Config\n",
|
||||
"\n",
|
||||
"import utils\n",
|
||||
"from transformers import LxmertForQuestionAnswering, LxmertTokenizer\n",
|
||||
"import wget\n",
|
||||
"import pickle\n",
|
||||
"import os\n",
|
||||
"from utils import Config\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# URL = \"https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/images/input.jpg\",\n",
|
||||
|
@ -31,19 +31,19 @@
|
||||
"source": [
|
||||
"# Includes\n",
|
||||
"\n",
|
||||
"import h5py\n",
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"from collections import OrderedDict\n",
|
||||
"\n",
|
||||
"from scipy import sparse\n",
|
||||
"import h5py\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from scipy import sparse\n",
|
||||
"from torch import nn\n",
|
||||
"\n",
|
||||
"from transformers import *\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"os.chdir(\"../../\")"
|
||||
]
|
||||
},
|
||||
|
File diff suppressed because one or more lines are too long
@ -25,7 +25,7 @@ import sys
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import tensorflow as tf
|
||||
@ -37,6 +37,7 @@ from transformers import (
|
||||
TF2_WEIGHTS_NAME,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
DataCollatorForMultipleChoice,
|
||||
DefaultDataCollator,
|
||||
HfArgumentParser,
|
||||
PushToHubCallback,
|
||||
@ -45,8 +46,7 @@ from transformers import (
|
||||
create_optimizer,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
|
||||
from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
@ -55,69 +55,6 @@ check_min_version("4.49.0.dev0")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# region Helper classes and functions
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForMultipleChoice:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
Args:
|
||||
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
||||
The tokenizer used for encoding the data.
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||
among:
|
||||
|
||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
|
||||
if provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
||||
lengths).
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (`int`, *optional*):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||
7.5 (Volta).
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features):
|
||||
label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
labels = [feature.pop(label_name) for feature in features]
|
||||
batch_size = len(features)
|
||||
num_choices = len(features[0]["input_ids"])
|
||||
flattened_features = [
|
||||
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
]
|
||||
flattened_features = list(chain(*flattened_features))
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flattened_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
# Un-flatten
|
||||
batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
|
||||
# Add back labels
|
||||
batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
|
||||
return batch
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Arguments
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
@ -424,8 +361,7 @@ def main():
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = DefaultDataCollator(return_tensors="np")
|
||||
else:
|
||||
# custom class defined above, as HF has no data collator for multiple choice
|
||||
data_collator = DataCollatorForMultipleChoice(tokenizer)
|
||||
data_collator = DataCollatorForMultipleChoice(tokenizer, return_tensors="tf")
|
||||
# endregion
|
||||
|
||||
with training_args.strategy.scope():
|
||||
|
113
examples/training/distributed_training.py
Normal file
113
examples/training/distributed_training.py
Normal file
@ -0,0 +1,113 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# Environment variables set by torch.distributed.launch
|
||||
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
|
||||
WORLD_SIZE = int(os.environ["WORLD_SIZE"])
|
||||
WORLD_RANK = int(os.environ["RANK"])
|
||||
|
||||
LOCAL_RANK = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
|
||||
WORLD_SIZE = int(os.environ["OMPI_COMM_WORLD_SIZE"])
|
||||
WORLD_RANK = int(os.environ["OMPI_COMM_WORLD_RANK"])
|
||||
|
||||
|
||||
def run(backend):
|
||||
tensor = torch.zeros(1)
|
||||
# Need to put tensor on a GPU device for nccl backend
|
||||
if backend == "nccl":
|
||||
device = torch.device("cuda:{}".format(LOCAL_RANK))
|
||||
tensor = tensor.to(device)
|
||||
|
||||
if WORLD_RANK == 0:
|
||||
for rank_recv in range(1, WORLD_SIZE):
|
||||
dist.send(tensor=tensor, dst=rank_recv)
|
||||
print("worker_{} sent data to Rank {}\n".format(0, rank_recv))
|
||||
else:
|
||||
dist.recv(tensor=tensor, src=0)
|
||||
print("worker_{} has received data from rank {}\n".format(WORLD_RANK, 0))
|
||||
|
||||
|
||||
def init_processes(backend):
|
||||
dist.init_process_group(backend, rank=WORLD_RANK, world_size=WORLD_SIZE)
|
||||
run(backend)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility."
|
||||
)
|
||||
parser.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo"])
|
||||
args = parser.parse_args()
|
||||
|
||||
init_processes(backend=args.backend)
|
||||
|
||||
""""
|
||||
python-m torch.distributed.launch \
|
||||
--nproc_per_node=2 --nnodes=2 --node_rank=0 \
|
||||
test_compile.py
|
||||
|
||||
python3 -m torch.distributed.launch \
|
||||
--nproc_per_node=2 --nnodes=2 --node_rank=1 \
|
||||
--master_addr=104.171.200.62 --master_port=1234 \
|
||||
main.py \
|
||||
--backend=nccl --use_syn --batch_size=8192 --arch=resnet152
|
||||
|
||||
|
||||
|
||||
mpirun -np 4 \
|
||||
-H 104.171.200.62:2,104.171.200.182:2 \
|
||||
-x MASTER_ADDR=104.171.200.62 \
|
||||
-x MASTER_PORT=1234 \
|
||||
-x PATH \
|
||||
-bind-to none -map-by slot \
|
||||
-mca pml ob1 -mca btl ^openib \
|
||||
python3 main.py
|
||||
"""
|
||||
|
||||
|
||||
""""
|
||||
You need a host file with the name of hosts.
|
||||
for example I have arthur@ip-26-0-162-46 and arthur@ip-26-0-162-239
|
||||
|
||||
________
|
||||
hostfile
|
||||
ip-26-0-162-46 slots=8
|
||||
ip-26-0-162-239 slots=8
|
||||
________
|
||||
|
||||
mpirun --hostfile hostfile -np 16 \
|
||||
--bind-to none --map-by slot \
|
||||
-x MASTER_ADDR=<master-node-ip> \
|
||||
-x MASTER_PORT=29500 \
|
||||
-x NCCL_DEBUG=INFO \
|
||||
-x NCCL_SOCKET_IFNAME=^lo,docker0 \
|
||||
-x CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
python your_script.py --backend nccl
|
||||
|
||||
|
||||
to get the master IP you need to do a few things:
|
||||
hostname -I | awk '{print $1}'
|
||||
|
||||
|
||||
Use `ping ip-26-0-162-46` to check if connected
|
||||
|
||||
26.0.162.46
|
||||
|
||||
mpirun --hostfile hostfile -np 16 \
|
||||
--bind-to none --map-by slot \
|
||||
-x MASTER_ADDR=26.0.162.46 \
|
||||
-x MASTER_PORT=29500 \
|
||||
-x NCCL_DEBUG=INFO \
|
||||
-x NCCL_SOCKET_IFNAME=^lo,docker0 \
|
||||
-x CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
python your_script.py --backend nccl
|
||||
|
||||
|
||||
mpirun --hostfile hostfile -np 2 -x NCCL_DEBUG=INFO python -c "import os;print(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])" -b 8 -e 128M -f 2 -g 1
|
||||
to test your setup
|
||||
"""
|
@ -101,6 +101,7 @@ _import_structure = {
|
||||
"data.data_collator": [
|
||||
"DataCollator",
|
||||
"DataCollatorForLanguageModeling",
|
||||
"DataCollatorForMultipleChoice",
|
||||
"DataCollatorForPermutationLanguageModeling",
|
||||
"DataCollatorForSeq2Seq",
|
||||
"DataCollatorForSOP",
|
||||
@ -5187,6 +5188,7 @@ if TYPE_CHECKING:
|
||||
from .data.data_collator import (
|
||||
DataCollator,
|
||||
DataCollatorForLanguageModeling,
|
||||
DataCollatorForMultipleChoice,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
DataCollatorForSeq2Seq,
|
||||
DataCollatorForSOP,
|
||||
|
@ -609,7 +609,7 @@ def compile_jinja_template(template):
|
||||
raise ImportError("template requires jinja2 to be installed.")
|
||||
|
||||
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
|
||||
raise ImportError("template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}.")
|
||||
raise ImportError(f"template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}.")
|
||||
|
||||
def raise_exception(message):
|
||||
raise TemplateError(message)
|
||||
|
@ -751,7 +751,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
|
||||
if len(id2label) != num_labels:
|
||||
raise ValueError(
|
||||
f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
|
||||
f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
|
||||
f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
|
||||
"one of them."
|
||||
)
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
from .data_collator import (
|
||||
DataCollatorForLanguageModeling,
|
||||
DataCollatorForMultipleChoice,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
DataCollatorForSeq2Seq,
|
||||
DataCollatorForSOP,
|
||||
|
@ -532,12 +532,95 @@ def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int]
|
||||
return result
|
||||
|
||||
|
||||
def tolist(x):
|
||||
if isinstance(x, list):
|
||||
return x
|
||||
elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import
|
||||
x = x.numpy()
|
||||
return x.tolist()
|
||||
@dataclass
|
||||
class DataCollatorForMultipleChoice(DataCollatorMixin):
|
||||
"""
|
||||
Data collator that dynamically pads a batch of nested examples for multiple choice, so that all choices
|
||||
of all examples have the same length.
|
||||
|
||||
Args:
|
||||
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
||||
The tokenizer used for encoding the data.
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
||||
Select a strategy to pad the returned sequences according to the model's padding side and padding index
|
||||
among:
|
||||
|
||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
|
||||
is provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
||||
lengths).
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (`int`, *optional*):
|
||||
Pad the sequence to a multiple of the provided value.
|
||||
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||
7.5 (Volta).
|
||||
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
return_tensors: str = "pt"
|
||||
|
||||
def torch_call(self, examples: List[Dict[str, Any]]): # Refactored implementation from the docs.
|
||||
import torch
|
||||
|
||||
# Take labels out of the examples beforehand, because they aren't nested.
|
||||
label_name = "label" if "label" in examples[0].keys() else "labels"
|
||||
labels = [example.pop(label_name) for example in examples]
|
||||
|
||||
batch_size = len(examples)
|
||||
num_choices = len(examples[0]["input_ids"])
|
||||
|
||||
# Go from e.g. 2 examples of 2 choices [{input_ids: [[1], [2]]}, {input_ids: [[3], [4]]}]
|
||||
# to 4 examples [{input_ids: [1]}, {input_ids: [2]}] + [{input_ids: [3]}, {input_ids: [4]}]
|
||||
flat_examples = sum(
|
||||
([{k: v[i] for k, v in example.items()} for i in range(num_choices)] for example in examples), start=[]
|
||||
)
|
||||
|
||||
# Pad all choices of all examples as if you're padding any other batch of examples.
|
||||
batch = self.tokenizer.pad(
|
||||
flat_examples,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Reshape from B*C x L into B x C x L, and add the labels back in.
|
||||
batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
|
||||
batch["labels"] = torch.tensor(labels, dtype=torch.int64)
|
||||
return batch
|
||||
|
||||
def tf_call(self, features): # Implementation taken from the docs.
|
||||
import tensorflow as tf
|
||||
|
||||
label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
labels = [feature.pop(label_name) for feature in features]
|
||||
batch_size = len(features)
|
||||
num_choices = len(features[0]["input_ids"])
|
||||
flattened_features = [
|
||||
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
]
|
||||
flattened_features = sum(flattened_features, []) # Sometimes written as list(chain(*flattened_features))
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flattened_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="tf",
|
||||
)
|
||||
|
||||
batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
|
||||
batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1268,6 +1351,14 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
||||
return inputs, labels
|
||||
|
||||
|
||||
def tolist(x):
|
||||
if isinstance(x, list):
|
||||
return x
|
||||
elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import
|
||||
x = x.numpy()
|
||||
return x.tolist()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSOP(DataCollatorForLanguageModeling):
|
||||
"""
|
||||
|
@ -35,9 +35,7 @@ def convert_tiktoken_to_fast(encoding: Any, output_dir: str):
|
||||
|
||||
dump_tiktoken_bpe(encoding._mergeable_ranks, save_file_absolute)
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"`tiktoken` is required to save a `tiktoken` file. Install it with " "`pip install tiktoken`."
|
||||
)
|
||||
raise ValueError("`tiktoken` is required to save a `tiktoken` file. Install it with `pip install tiktoken`.")
|
||||
|
||||
tokenizer = TikTokenConverter(
|
||||
vocab_file=save_file_absolute, pattern=encoding._pat_str, additional_special_tokens=encoding._special_tokens
|
||||
|
@ -270,7 +270,7 @@ class BeitSelfAttention(nn.Module):
|
||||
self.config = config
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
||||
f"The hidden size {(config.hidden_size,)} is not a multiple of the number of attention "
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
|
@ -271,7 +271,7 @@ class Data2VecVisionSelfAttention(nn.Module):
|
||||
self.config = config
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
||||
f"The hidden size {(config.hidden_size,)} is not a multiple of the number of attention "
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
|
@ -183,7 +183,7 @@ class OlmoConfig(PretrainedConfig):
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
|
||||
f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
|
@ -166,7 +166,7 @@ class Olmo2Config(PretrainedConfig):
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
|
||||
f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
|
@ -2195,7 +2195,7 @@ def pytest_terminal_summary_main(tr, id):
|
||||
f.write("slowest durations\n")
|
||||
for i, rep in enumerate(dlist):
|
||||
if rep.duration < durations_min:
|
||||
f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
|
||||
f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
|
||||
break
|
||||
f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
|
||||
|
||||
@ -2580,7 +2580,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
|
||||
process.join(timeout=timeout)
|
||||
|
||||
if results["error"] is not None:
|
||||
test_case.fail(f'{results["error"]}')
|
||||
test_case.fail(f"{results['error']}")
|
||||
|
||||
|
||||
def run_test_using_subprocess(func):
|
||||
|
@ -182,7 +182,7 @@ class DepthProModelTester:
|
||||
model_name = model.__class__.__name__
|
||||
self.parent.assertTrue(
|
||||
diff <= 1e-03,
|
||||
msg=(f"Batched and Single row outputs are not equal in {model_name} for fov. " f"Difference={diff}."),
|
||||
msg=(f"Batched and Single row outputs are not equal in {model_name} for fov. Difference={diff}."),
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
@ -687,7 +687,7 @@ class TrainerIntegrationCommon:
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
shard_files = [
|
||||
shard_name.replace(f".{extension}", f"-{idx+1:05d}-of-{len(keys):05d}.{extension}")
|
||||
shard_name.replace(f".{extension}", f"-{idx + 1:05d}-of-{len(keys):05d}.{extension}")
|
||||
for idx in range(len(keys))
|
||||
]
|
||||
index = {"metadata": {}, "weight_map": {key: shard_files[i] for i, key in enumerate(keys)}}
|
||||
|
Loading…
Reference in New Issue
Block a user