diff --git a/transformers/data/processors/squad.py b/transformers/data/processors/squad.py
index b17e626c984..338bae0c518 100644
--- a/transformers/data/processors/squad.py
+++ b/transformers/data/processors/squad.py
@@ -7,7 +7,11 @@ import numpy as np
 
 from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
 from .utils import DataProcessor, InputExample, InputFeatures
-from ...file_utils import is_tf_available
+from ...file_utils import is_tf_available, is_torch_available
+
+if is_torch_available():
+    import torch
+    from torch.utils.data import TensorDataset
 
 if is_tf_available():
     import tensorflow as tf
@@ -73,7 +77,8 @@ def _is_whitespace(c):
     return False
 
 def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
-                                       doc_stride, max_query_length, is_training):
+                                       doc_stride, max_query_length, is_training,
+                                       return_dataset=False):
     """
     Converts a list of examples into a list of features that can be directly given as input to a model.
     It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
@@ -84,7 +89,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         max_seq_length: The maximum sequence length of the inputs.
         doc_stride: The stride used when the context is too large and is split across several features.
         max_query_length: The maximum length of the query.
-        is_training: wheter to create features for model evaluation or model training.
+        is_training: whether to create features for model evaluation or model training.
+        return_dataset: Default False. Can be 'pt' or 'tf'.
+            if 'pt': returns a torch.utils.data.TensorDataset,
+            if 'tf': returns a tf.data.Dataset
 
     Returns:
         list of :class:`~transformers.data.processors.squad.SquadFeatures`
@@ -264,6 +272,31 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
 
             unique_id += 1
 
+    if return_dataset == 'pt':
+        if not is_torch_available():
+            raise ImportError("PyTorch must be installed to return a PyTorch dataset.")
+
+        # Convert to Tensors and build dataset
+        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+        all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
+        all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
+        all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
+        all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
+
+        if not is_training:
+            all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
+            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
+                                    all_example_index, all_cls_index, all_p_mask)
+        else:
+            all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
+            all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
+            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
+                                    all_start_positions, all_end_positions,
+                                    all_cls_index, all_p_mask)
+
+        return features, dataset
+
     return features
@@ -359,7 +392,7 @@ class SquadProcessor(DataProcessor):
         if self.dev_file is None:
             raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
 
-        with open(os.path.join(data_dir, self.dev_file if filename is not None else filename), "r", encoding='utf-8') as reader:
+        with open(os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding='utf-8') as reader:
            input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "dev")
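
Reviewer note: a minimal sketch of how the new `return_dataset='pt'` path is meant to be consumed. `SquadV2Processor`, `BertTokenizer`, and the tensor ordering match this diff; the data directory, checkpoint name, and hyperparameter values are illustrative placeholders.

```python
# Minimal usage sketch for return_dataset='pt' (eval mode).
# Assumes a SQuAD v2.0 dev file under /path/to/squad; paths, lengths,
# and batch size below are placeholders.
from torch.utils.data import DataLoader, SequentialSampler

from transformers import BertTokenizer
from transformers.data.processors.squad import (
    SquadV2Processor,
    squad_convert_examples_to_features,
)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
processor = SquadV2Processor()
examples = processor.get_dev_examples("/path/to/squad")  # falls back to self.dev_file

# With return_dataset='pt' the function now returns (features, dataset)
# instead of only the feature list.
features, dataset = squad_convert_examples_to_features(
    examples,
    tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset='pt',
)

# Eval-mode batches unpack in the order the TensorDataset is built:
# input_ids, input_mask, segment_ids, example_index, cls_index, p_mask.
# Training mode swaps example_index for start/end positions (see the
# else branch in the diff).
dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=8)
for input_ids, input_mask, segment_ids, example_index, cls_index, p_mask in dataloader:
    break  # feed each batch to the model here
```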
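The final hunk is worth a second look: the old conditional `self.dev_file if filename is not None else filename` inverted the fallback, so passing a filename loaded the default dev file and passing nothing tried to open `None`. A tiny sketch of the corrected selection logic; the file names here are hypothetical stand-ins:

```python
# Selection logic after the fix in get_dev_examples.
# 'dev-v2.0.json' stands in for the processor's default dev_file;
# 'custom-dev.json' is a hypothetical user-supplied override.
dev_file = "dev-v2.0.json"

for filename in (None, "custom-dev.json"):
    chosen = dev_file if filename is None else filename
    print(f"{filename!r} -> {chosen}")

# Prints:
# None -> dev-v2.0.json                    (default dev file)
# 'custom-dev.json' -> custom-dev.json     (override wins)
```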