add overwrite - fix ner decoding

This commit is contained in:
thomwolf 2019-12-20 21:47:15 +01:00
parent f79a7dc661
commit 4775ec354b
2 changed files with 20 additions and 14 deletions

View File

@@ -32,7 +32,8 @@ def run_command_factory(args):
reader = PipelineDataFormat.from_str(format=format,
output_path=args.output,
input_path=args.input,
column=args.column if args.column else nlp.default_input_names)
column=args.column if args.column else nlp.default_input_names,
overwrite=args.overwrite)
return RunCommand(nlp, reader)
@@ -54,6 +55,7 @@ class RunCommand(BaseTransformersCLICommand):
run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
run_parser.add_argument('--overwrite', action='store_true', help='Allow overwriting the output file.')
run_parser.set_defaults(func=run_command_factory)
def run(self):
@@ -61,6 +63,7 @@ class RunCommand(BaseTransformersCLICommand):
for entry in self._reader:
output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
print(output)
if isinstance(output, dict):
outputs.append(output)
else:
@@ -68,10 +71,10 @@ class RunCommand(BaseTransformersCLICommand):
# Saving data
if self._nlp.binary_output:
binary_path = self._reader.save_binary(output)
binary_path = self._reader.save_binary(outputs)
logger.warning('Current pipeline requires output to be in binary format, saving at {}'.format(binary_path))
else:
self._reader.save(output)
self._reader.save(outputs)

View File

@@ -107,7 +107,7 @@ class PipelineDataFormat:
"""
SUPPORTED_FORMATS = ['json', 'csv', 'pipe']
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
self.output_path = output_path
self.input_path = input_path
self.column = column.split(',') if column is not None else ['']
@@ -116,7 +116,7 @@ class PipelineDataFormat:
if self.is_multi_columns:
self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]
if output_path is not None:
if output_path is not None and not overwrite:
if exists(abspath(self.output_path)):
raise OSError('{} already exists on disk'.format(self.output_path))
@@ -152,25 +152,26 @@ class PipelineDataFormat:
return binary_path
@staticmethod
def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
if format == 'json':
return JsonPipelineDataFormat(output_path, input_path, column)
return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
elif format == 'csv':
return CsvPipelineDataFormat(output_path, input_path, column)
return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
elif format == 'pipe':
return PipedPipelineDataFormat(output_path, input_path, column)
return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
else:
raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(format))
class CsvPipelineDataFormat(PipelineDataFormat):
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
super().__init__(output_path, input_path, column)
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
super().__init__(output_path, input_path, column, overwrite=overwrite)
def __iter__(self):
with open(self.input_path, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
print(row, self.column)
if self.is_multi_columns:
yield {k: row[c] for k, c in self.column}
else:
@@ -185,8 +186,8 @@ class CsvPipelineDataFormat(PipelineDataFormat):
class JsonPipelineDataFormat(PipelineDataFormat):
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
super().__init__(output_path, input_path, column)
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
super().__init__(output_path, input_path, column, overwrite=overwrite)
with open(input_path, 'r') as f:
self._entries = json.load(f)
@@ -460,6 +461,8 @@ class NerPipeline(Pipeline):
Named Entity Recognition pipeline using ModelForTokenClassification head.
"""
default_input_names = 'sequences'
def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None, framework: Optional[str] = None,
args_parser: ArgumentHandler = None, device: int = -1,
@@ -504,7 +507,7 @@ class NerPipeline(Pipeline):
for idx, label_idx in enumerate(labels_idx):
if self.model.config.id2label[label_idx] not in self.ignore_labels:
answer += [{
'word': self.tokenizer.decode(int(input_ids[idx])),
'word': self.tokenizer.decode([int(input_ids[idx])]),
'score': score[idx][label_idx].item(),
'entity': self.model.config.id2label[label_idx]
}]