Merge pull request #2244 from huggingface/fix-tok-pipe

Fix Camembert and XLM-R `decode` method - Fix NER pipeline alignment
This commit is contained in:
Thomas Wolf 2019-12-20 22:10:39 +01:00 committed by GitHub
commit d0f8b9a978
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 44 additions and 34 deletions

View File

@ -32,7 +32,8 @@ def run_command_factory(args):
reader = PipelineDataFormat.from_str(format=format,
output_path=args.output,
input_path=args.input,
column=args.column if args.column else nlp.default_input_names)
column=args.column if args.column else nlp.default_input_names,
overwrite=args.overwrite)
return RunCommand(nlp, reader)
@ -54,6 +55,7 @@ class RunCommand(BaseTransformersCLICommand):
run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
run_parser.add_argument('--overwrite', action='store_true', help='Allow overwriting the output file.')
run_parser.set_defaults(func=run_command_factory)
def run(self):
@ -68,10 +70,10 @@ class RunCommand(BaseTransformersCLICommand):
# Saving data
if self._nlp.binary_output:
binary_path = self._reader.save_binary(output)
binary_path = self._reader.save_binary(outputs)
logger.warning('Current pipeline requires output to be in binary format, saving at {}'.format(binary_path))
else:
self._reader.save(output)
self._reader.save(outputs)

View File

@ -287,7 +287,8 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
return
content_length = response.headers.get('Content-Length')
total = resume_size + int(content_length) if content_length is not None else None
progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size, desc="Downloading")
progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size,
desc="Downloading", disable=bool(logger.level<=logging.INFO))
for chunk in response.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))

View File

@ -107,7 +107,7 @@ class PipelineDataFormat:
"""
SUPPORTED_FORMATS = ['json', 'csv', 'pipe']
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
self.output_path = output_path
self.input_path = input_path
self.column = column.split(',') if column is not None else ['']
@ -116,7 +116,7 @@ class PipelineDataFormat:
if self.is_multi_columns:
self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]
if output_path is not None:
if output_path is not None and not overwrite:
if exists(abspath(self.output_path)):
raise OSError('{} already exists on disk'.format(self.output_path))
@ -152,20 +152,20 @@ class PipelineDataFormat:
return binary_path
@staticmethod
def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
if format == 'json':
return JsonPipelineDataFormat(output_path, input_path, column)
return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
elif format == 'csv':
return CsvPipelineDataFormat(output_path, input_path, column)
return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
elif format == 'pipe':
return PipedPipelineDataFormat(output_path, input_path, column)
return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
else:
raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(format))
class CsvPipelineDataFormat(PipelineDataFormat):
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
super().__init__(output_path, input_path, column)
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
super().__init__(output_path, input_path, column, overwrite=overwrite)
def __iter__(self):
with open(self.input_path, 'r') as f:
@ -185,8 +185,8 @@ class CsvPipelineDataFormat(PipelineDataFormat):
class JsonPipelineDataFormat(PipelineDataFormat):
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
super().__init__(output_path, input_path, column)
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
super().__init__(output_path, input_path, column, overwrite=overwrite)
with open(input_path, 'r') as f:
self._entries = json.load(f)
@ -460,10 +460,12 @@ class NerPipeline(Pipeline):
Named Entity Recognition pipeline using ModelForTokenClassification head.
"""
default_input_names = 'sequences'
def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None, framework: Optional[str] = None,
args_parser: ArgumentHandler = None, device: int = -1,
binary_output: bool = False):
binary_output: bool = False, ignore_labels=['O']):
super().__init__(model=model,
tokenizer=tokenizer,
modelcard=modelcard,
@ -473,17 +475,12 @@ class NerPipeline(Pipeline):
binary_output=binary_output)
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
self.ignore_labels = ignore_labels
def __call__(self, *texts, **kwargs):
inputs, answers = self._args_parser(*texts, **kwargs), []
for sentence in inputs:
# Ugly token to word idx mapping (for now)
token_to_word, words = [], self._basic_tokenizer.tokenize(sentence)
for i, w in enumerate(words):
tokens = self.tokenizer.tokenize(w)
token_to_word += [i] * len(tokens)
# Manage correct placement of the tensors
with self.device_placement():
@ -496,30 +493,28 @@ class NerPipeline(Pipeline):
# Forward
if self.framework == 'tf':
entities = self.model(tokens)[0][0].numpy()
input_ids = tokens['input_ids'].numpy()[0]
else:
with torch.no_grad():
entities = self.model(**tokens)[0][0].cpu().numpy()
input_ids = tokens['input_ids'].cpu().numpy()[0]
# Normalize scores
answer, token_start = [], 1
for idx, word in groupby(token_to_word):
score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
labels_idx = score.argmax(axis=-1)
# Sum log prob over token, then normalize across labels
score = np.exp(entities[token_start]) / np.exp(entities[token_start]).sum(-1, keepdims=True)
label_idx = score.argmax()
if label_idx > 0:
answer = []
for idx, label_idx in enumerate(labels_idx):
if self.model.config.id2label[label_idx] not in self.ignore_labels:
answer += [{
'word': words[idx],
'score': score[label_idx].item(),
'word': self.tokenizer.decode([int(input_ids[idx])]),
'score': score[idx][label_idx].item(),
'entity': self.model.config.id2label[label_idx]
}]
# Update token start
token_start += len(list(word))
# Append
answers += [answer]
if len(answers) == 1:
return answers[0]
return answers

View File

@ -22,6 +22,7 @@ from shutil import copyfile
import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE
logger = logging.getLogger(__name__)
@ -145,6 +146,11 @@ class CamembertTokenizer(PreTrainedTokenizer):
return self.fairseq_ids_to_tokens[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)
def convert_tokens_to_string(self, tokens):
    """Join sub-word token strings into a single string.

    The tokens are concatenated without a separator, every occurrence of
    SPIECE_UNDERLINE is replaced by a space, and leading/trailing
    whitespace is stripped from the result.
    """
    joined = ''.join(tokens)
    spaced = joined.replace(SPIECE_UNDERLINE, ' ')
    return spaced.strip()
def save_vocabulary(self, save_directory):
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
to a directory.

View File

@ -22,6 +22,7 @@ from shutil import copyfile
import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE
logger = logging.getLogger(__name__)
@ -161,6 +162,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
return self.fairseq_ids_to_tokens[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)
def convert_tokens_to_string(self, tokens):
    """Join sub-word token strings into a single string.

    The tokens are concatenated without a separator, every occurrence of
    SPIECE_UNDERLINE is replaced by a space, and leading/trailing
    whitespace is stripped from the result.
    """
    pieces = ''.join(tokens)
    text = pieces.replace(SPIECE_UNDERLINE, ' ')
    return text.strip()
def save_vocabulary(self, save_directory):
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
to a directory.