mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
add overwrite - fix ner decoding
parent f79a7dc661
commit 4775ec354b
@@ -32,7 +32,8 @@ def run_command_factory(args):
     reader = PipelineDataFormat.from_str(format=format,
                                          output_path=args.output,
                                          input_path=args.input,
-                                         column=args.column if args.column else nlp.default_input_names)
+                                         column=args.column if args.column else nlp.default_input_names,
+                                         overwrite=args.overwrite)
     return RunCommand(nlp, reader)
@@ -54,6 +55,7 @@ class RunCommand(BaseTransformersCLICommand):
         run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
         run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
         run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
+        run_parser.add_argument('--overwrite', action='store_true', help='Allow overwriting the output file.')
         run_parser.set_defaults(func=run_command_factory)
 
     def run(self):
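Note: the new --overwrite flag is a plain argparse boolean. A minimal standalone sketch of its semantics (not the library code):

    import argparse

    # action='store_true' means: flag absent -> False, flag present -> True.
    parser = argparse.ArgumentParser()
    parser.add_argument('--overwrite', action='store_true',
                        help='Allow overwriting the output file.')
    assert parser.parse_args([]).overwrite is False
    assert parser.parse_args(['--overwrite']).overwrite is True

run_command_factory then forwards the parsed value as the overwrite keyword, as shown in the first hunk above.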
@@ -61,6 +63,7 @@ class RunCommand(BaseTransformersCLICommand):
 
         for entry in self._reader:
             output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
+            print(output)
             if isinstance(output, dict):
                 outputs.append(output)
             else:
@@ -68,10 +71,10 @@ class RunCommand(BaseTransformersCLICommand):
 
         # Saving data
         if self._nlp.binary_output:
-            binary_path = self._reader.save_binary(output)
+            binary_path = self._reader.save_binary(outputs)
             logger.warning('Current pipeline requires output to be in binary format, saving at {}'.format(binary_path))
         else:
-            self._reader.save(output)
+            self._reader.save(outputs)
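Note: this hunk also fixes a save bug: save() and save_binary() previously received the per-iteration variable `output`, so only the last result was persisted. A minimal sketch of the corrected accumulation (pipeline and the entries list are stand-ins, not library names):

    def pipeline(entry):                 # stand-in for the loaded nlp object
        return {'entry': entry}

    outputs = []
    for entry in ['a', 'b', 'c']:        # stands in for self._reader
        output = pipeline(entry)
        outputs.append(output)

    print(len(outputs))                  # 3: the whole batch, not just the last row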
@@ -107,7 +107,7 @@ class PipelineDataFormat:
     """
     SUPPORTED_FORMATS = ['json', 'csv', 'pipe']
 
-    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
         self.output_path = output_path
         self.input_path = input_path
         self.column = column.split(',') if column is not None else ['']
@@ -116,7 +116,7 @@ class PipelineDataFormat:
         if self.is_multi_columns:
             self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]
 
-        if output_path is not None:
+        if output_path is not None and not overwrite:
             if exists(abspath(self.output_path)):
                 raise OSError('{} already exists on disk'.format(self.output_path))
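Note: a minimal sketch of the guard above, as a free function (check_output_path is a made-up name for illustration):

    from os.path import abspath, exists

    def check_output_path(output_path, overwrite=False):
        # With overwrite=True the existence check is skipped entirely;
        # otherwise an existing file still raises, as before.
        if output_path is not None and not overwrite:
            if exists(abspath(output_path)):
                raise OSError('{} already exists on disk'.format(output_path))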
@@ -152,25 +152,26 @@ class PipelineDataFormat:
         return binary_path
 
     @staticmethod
-    def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+    def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
         if format == 'json':
-            return JsonPipelineDataFormat(output_path, input_path, column)
+            return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
         elif format == 'csv':
-            return CsvPipelineDataFormat(output_path, input_path, column)
+            return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
         elif format == 'pipe':
-            return PipedPipelineDataFormat(output_path, input_path, column)
+            return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
         else:
             raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(format))
 
 
 class CsvPipelineDataFormat(PipelineDataFormat):
-    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
-        super().__init__(output_path, input_path, column)
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
+        super().__init__(output_path, input_path, column, overwrite=overwrite)
 
     def __iter__(self):
         with open(self.input_path, 'r') as f:
             reader = csv.DictReader(f)
             for row in reader:
+                print(row, self.column)
                 if self.is_multi_columns:
                     yield {k: row[c] for k, c in self.column}
                 else:
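Note: a hypothetical call mirroring run_command_factory, showing the opt-in (the import path reflects this era's module layout, and the file names are made up):

    from transformers.pipelines import PipelineDataFormat

    reader = PipelineDataFormat.from_str(format='csv',
                                         output_path='out.json',
                                         input_path='samples.csv',
                                         column='sequences',
                                         overwrite=True)   # a re-run now replaces out.json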
@@ -185,8 +186,8 @@ class CsvPipelineDataFormat(PipelineDataFormat):
 
 
 class JsonPipelineDataFormat(PipelineDataFormat):
-    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
-        super().__init__(output_path, input_path, column)
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
+        super().__init__(output_path, input_path, column, overwrite=overwrite)
 
         with open(input_path, 'r') as f:
             self._entries = json.load(f)
@@ -460,6 +461,8 @@ class NerPipeline(Pipeline):
     Named Entity Recognition pipeline using ModelForTokenClassification head.
     """
 
+    default_input_names = 'sequences'
+
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
                  modelcard: ModelCard = None, framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None, device: int = -1,
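Note: this class attribute feeds the CLI default from the first hunk: when --column is omitted, the factory falls back to nlp.default_input_names, which for NER is now 'sequences'. A small stand-in sketch of that fallback:

    class FakeNerPipeline:               # stand-in, not the library class
        default_input_names = 'sequences'

    def resolve_column(column_arg, nlp):
        # Mirrors: args.column if args.column else nlp.default_input_names
        return column_arg if column_arg else nlp.default_input_names

    print(resolve_column(None, FakeNerPipeline()))   # -> 'sequences'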
@@ -504,7 +507,7 @@ class NerPipeline(Pipeline):
             for idx, label_idx in enumerate(labels_idx):
                 if self.model.config.id2label[label_idx] not in self.ignore_labels:
                     answer += [{
-                        'word': self.tokenizer.decode(int(input_ids[idx])),
+                        'word': self.tokenizer.decode([int(input_ids[idx])]),
                         'score': score[idx][label_idx].item(),
                         'entity': self.model.config.id2label[label_idx]
                     }]
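Note: this is the "fix ner decoding" half of the commit: tokenizer.decode expects a sequence of token ids, so the single id is wrapped in a one-element list. A hedged illustration ('bert-base-cased' is only an example checkpoint, not taken from this commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
    ids = tokenizer.encode('Hugging Face is based in New York')
    # A one-element list decodes to that single token's text; a bare int
    # does not match decode's expected input shape.
    print(tokenizer.decode([int(ids[1])]))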