Default the summaries output dir to the documents dir when none is given

This commit is contained in:
Rémi Louf 2019-12-05 19:09:47 +01:00 committed by Julien Chaumond
parent 693606a75c
commit 3a9a9f7861
2 changed files with 8 additions and 5 deletions

View File

@ -31,9 +31,7 @@ Batch = namedtuple(
def evaluate(args):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
model = bertabs = BertAbs.from_pretrained(
"bertabs-finetuned-{}".format(args.finetuned_model)
)
model = bertabs = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
bertabs.to(args.device)
bertabs.eval()
@ -195,8 +193,8 @@ def main():
"--summaries_output_dir",
default=None,
type=str,
required=True,
help="The folder in wich the summaries should be written.",
required=False,
help="The folder in wich the summaries should be written. Defaults to the folder where the documents are",
)
# EVALUATION options
parser.add_argument(
@ -242,6 +240,9 @@ def main():
args = parser.parse_args()
args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
if not args.summaries_output_dir:
args.summaries_output_dir = args.documents_dir
if not documents_dir_is_valid(args.documents_dir):
raise FileNotFoundError(
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."

View File

@ -39,6 +39,8 @@ class SummarizationDataset(Dataset):
self.documents = []
story_filenames_list = os.listdir(path)
for story_filename in story_filenames_list:
if "summary" in story_filename:
continue
path_to_story = os.path.join(path, story_filename)
if not os.path.isfile(path_to_story):
continue