mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Return dataset (pytorch)
This commit is contained in:
parent
7a03519975
commit
ce158a076f
@ -7,7 +7,11 @@ import numpy as np
|
||||
|
||||
from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
|
||||
from .utils import DataProcessor, InputExample, InputFeatures
|
||||
from ...file_utils import is_tf_available
|
||||
from ...file_utils import is_tf_available, is_torch_available
|
||||
|
||||
if is_torch_available:
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
@ -73,7 +77,8 @@ def _is_whitespace(c):
|
||||
return False
|
||||
|
||||
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
doc_stride, max_query_length, is_training):
|
||||
doc_stride, max_query_length, is_training,
|
||||
return_dataset=False):
|
||||
"""
|
||||
Converts a list of examples into a list of features that can be directly given as input to a model.
|
||||
It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
|
||||
@ -84,7 +89,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
max_seq_length: The maximum sequence length of the inputs.
|
||||
doc_stride: The stride used when the context is too large and is split across several features.
|
||||
max_query_length: The maximum length of the query.
|
||||
is_training: wheter to create features for model evaluation or model training.
|
||||
is_training: whether to create features for model evaluation or model training.
|
||||
return_dataset: Default False. Either 'pt' or 'tf'.
|
||||
if 'pt': returns a torch.data.TensorDataset,
|
||||
if 'tf': returns a tf.data.Dataset
|
||||
|
||||
Returns:
|
||||
list of :class:`~transformers.data.processors.squad.SquadFeatures`
|
||||
@ -264,6 +272,31 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
|
||||
unique_id += 1
|
||||
|
||||
if return_dataset == 'pt':
|
||||
if not is_torch_available():
|
||||
raise ImportError("Pytorch must be installed to return a pytorch dataset.")
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
||||
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||
|
||||
if not is_training:
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_example_index, all_cls_index, all_p_mask)
|
||||
else:
|
||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_start_positions, all_end_positions,
|
||||
all_cls_index, all_p_mask)
|
||||
|
||||
return features, dataset
|
||||
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@ -359,7 +392,7 @@ class SquadProcessor(DataProcessor):
|
||||
if self.dev_file is None:
|
||||
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
|
||||
|
||||
with open(os.path.join(data_dir, self.dev_file if filename is not None else filename), "r", encoding='utf-8') as reader:
|
||||
with open(os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding='utf-8') as reader:
|
||||
input_data = json.load(reader)["data"]
|
||||
return self._create_examples(input_data, "dev")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user