diff --git a/tensorflow_code/run_squad.py b/tensorflow_code/run_squad.py index fb1c4b5ed80..0841afa8c82 100644 --- a/tensorflow_code/run_squad.py +++ b/tensorflow_code/run_squad.py @@ -207,7 +207,7 @@ class InputFeatures(object): self.end_position = end_position -def read_squad_examples(input_file, is_training): +def read_squad_examples(input_file, is_training, max_num=-1): """Read a SQuAD json file into a list of SquadExample.""" with tf.gfile.Open(input_file, "r") as reader: input_data = json.load(reader)["data"] @@ -219,6 +219,8 @@ def read_squad_examples(input_file, is_training): examples = [] for entry in input_data: + if max_num != -1 and len(examples) > max_num: + break for paragraph in entry["paragraphs"]: paragraph_text = paragraph["context"] doc_tokens = []