Return dataset (pytorch)

This commit is contained in:
LysandreJik 2019-12-04 17:55:52 -05:00
parent 7a03519975
commit ce158a076f

View File

@ -7,7 +7,11 @@ import numpy as np
from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
from .utils import DataProcessor, InputExample, InputFeatures
from ...file_utils import is_tf_available
from ...file_utils import is_tf_available, is_torch_available
if is_torch_available:
import torch
from torch.utils.data import TensorDataset
if is_tf_available():
import tensorflow as tf
@ -73,7 +77,8 @@ def _is_whitespace(c):
return False
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
doc_stride, max_query_length, is_training):
doc_stride, max_query_length, is_training,
return_dataset=False):
"""
Converts a list of examples into a list of features that can be directly given as input to a model.
It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
@ -84,7 +89,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
max_seq_length: The maximum sequence length of the inputs.
doc_stride: The stride used when the context is too large and is split across several features.
max_query_length: The maximum length of the query.
is_training: wheter to create features for model evaluation or model training.
is_training: whether to create features for model evaluation or model training.
return_dataset: Default False. Either 'pt' or 'tf'.
if 'pt': returns a torch.data.TensorDataset,
if 'tf': returns a tf.data.Dataset
Returns:
list of :class:`~transformers.data.processors.squad.SquadFeatures`
@ -264,6 +272,31 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
unique_id += 1
if return_dataset == 'pt':
if not is_torch_available():
raise ImportError("Pytorch must be installed to return a pytorch dataset.")
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
if not is_training:
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
all_example_index, all_cls_index, all_p_mask)
else:
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
all_start_positions, all_end_positions,
all_cls_index, all_p_mask)
return features, dataset
return features
@ -359,7 +392,7 @@ class SquadProcessor(DataProcessor):
if self.dev_file is None:
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
with open(os.path.join(data_dir, self.dev_file if filename is not None else filename), "r", encoding='utf-8') as reader:
with open(os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding='utf-8') as reader:
input_data = json.load(reader)["data"]
return self._create_examples(input_data, "dev")