mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Merge pull request #2244 from huggingface/fix-tok-pipe
Fix Camembert and XLM-R `decode` method - Fix NER pipeline alignment
This commit is contained in:
commit
d0f8b9a978
@ -32,7 +32,8 @@ def run_command_factory(args):
|
||||
reader = PipelineDataFormat.from_str(format=format,
|
||||
output_path=args.output,
|
||||
input_path=args.input,
|
||||
column=args.column if args.column else nlp.default_input_names)
|
||||
column=args.column if args.column else nlp.default_input_names,
|
||||
overwrite=args.overwrite)
|
||||
return RunCommand(nlp, reader)
|
||||
|
||||
|
||||
@ -54,6 +55,7 @@ class RunCommand(BaseTransformersCLICommand):
|
||||
run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
|
||||
run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
|
||||
run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
|
||||
run_parser.add_argument('--overwrite', action='store_true', help='Allow overwriting the output file.')
|
||||
run_parser.set_defaults(func=run_command_factory)
|
||||
|
||||
def run(self):
|
||||
@ -68,10 +70,10 @@ class RunCommand(BaseTransformersCLICommand):
|
||||
|
||||
# Saving data
|
||||
if self._nlp.binary_output:
|
||||
binary_path = self._reader.save_binary(output)
|
||||
binary_path = self._reader.save_binary(outputs)
|
||||
logger.warning('Current pipeline requires output to be in binary format, saving at {}'.format(binary_path))
|
||||
else:
|
||||
self._reader.save(output)
|
||||
self._reader.save(outputs)
|
||||
|
||||
|
||||
|
||||
|
@ -287,7 +287,8 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
||||
return
|
||||
content_length = response.headers.get('Content-Length')
|
||||
total = resume_size + int(content_length) if content_length is not None else None
|
||||
progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size, desc="Downloading")
|
||||
progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size,
|
||||
desc="Downloading", disable=bool(logger.level<=logging.INFO))
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
progress.update(len(chunk))
|
||||
|
@ -107,7 +107,7 @@ class PipelineDataFormat:
|
||||
"""
|
||||
SUPPORTED_FORMATS = ['json', 'csv', 'pipe']
|
||||
|
||||
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
|
||||
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
|
||||
self.output_path = output_path
|
||||
self.input_path = input_path
|
||||
self.column = column.split(',') if column is not None else ['']
|
||||
@ -116,7 +116,7 @@ class PipelineDataFormat:
|
||||
if self.is_multi_columns:
|
||||
self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]
|
||||
|
||||
if output_path is not None:
|
||||
if output_path is not None and not overwrite:
|
||||
if exists(abspath(self.output_path)):
|
||||
raise OSError('{} already exists on disk'.format(self.output_path))
|
||||
|
||||
@ -152,20 +152,20 @@ class PipelineDataFormat:
|
||||
return binary_path
|
||||
|
||||
@staticmethod
|
||||
def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
|
||||
def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
|
||||
if format == 'json':
|
||||
return JsonPipelineDataFormat(output_path, input_path, column)
|
||||
return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
|
||||
elif format == 'csv':
|
||||
return CsvPipelineDataFormat(output_path, input_path, column)
|
||||
return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
|
||||
elif format == 'pipe':
|
||||
return PipedPipelineDataFormat(output_path, input_path, column)
|
||||
return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
|
||||
else:
|
||||
raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(format))
|
||||
|
||||
|
||||
class CsvPipelineDataFormat(PipelineDataFormat):
|
||||
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
|
||||
super().__init__(output_path, input_path, column)
|
||||
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
|
||||
super().__init__(output_path, input_path, column, overwrite=overwrite)
|
||||
|
||||
def __iter__(self):
|
||||
with open(self.input_path, 'r') as f:
|
||||
@ -185,8 +185,8 @@ class CsvPipelineDataFormat(PipelineDataFormat):
|
||||
|
||||
|
||||
class JsonPipelineDataFormat(PipelineDataFormat):
|
||||
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
|
||||
super().__init__(output_path, input_path, column)
|
||||
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
|
||||
super().__init__(output_path, input_path, column, overwrite=overwrite)
|
||||
|
||||
with open(input_path, 'r') as f:
|
||||
self._entries = json.load(f)
|
||||
@ -460,10 +460,12 @@ class NerPipeline(Pipeline):
|
||||
Named Entity Recognition pipeline using ModelForTokenClassification head.
|
||||
"""
|
||||
|
||||
default_input_names = 'sequences'
|
||||
|
||||
def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
|
||||
modelcard: ModelCard = None, framework: Optional[str] = None,
|
||||
args_parser: ArgumentHandler = None, device: int = -1,
|
||||
binary_output: bool = False):
|
||||
binary_output: bool = False, ignore_labels=['O']):
|
||||
super().__init__(model=model,
|
||||
tokenizer=tokenizer,
|
||||
modelcard=modelcard,
|
||||
@ -473,17 +475,12 @@ class NerPipeline(Pipeline):
|
||||
binary_output=binary_output)
|
||||
|
||||
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
|
||||
self.ignore_labels = ignore_labels
|
||||
|
||||
def __call__(self, *texts, **kwargs):
|
||||
inputs, answers = self._args_parser(*texts, **kwargs), []
|
||||
for sentence in inputs:
|
||||
|
||||
# Ugly token to word idx mapping (for now)
|
||||
token_to_word, words = [], self._basic_tokenizer.tokenize(sentence)
|
||||
for i, w in enumerate(words):
|
||||
tokens = self.tokenizer.tokenize(w)
|
||||
token_to_word += [i] * len(tokens)
|
||||
|
||||
# Manage correct placement of the tensors
|
||||
with self.device_placement():
|
||||
|
||||
@ -496,30 +493,28 @@ class NerPipeline(Pipeline):
|
||||
# Forward
|
||||
if self.framework == 'tf':
|
||||
entities = self.model(tokens)[0][0].numpy()
|
||||
input_ids = tokens['input_ids'].numpy()[0]
|
||||
else:
|
||||
with torch.no_grad():
|
||||
entities = self.model(**tokens)[0][0].cpu().numpy()
|
||||
input_ids = tokens['input_ids'].cpu().numpy()[0]
|
||||
|
||||
# Normalize scores
|
||||
answer, token_start = [], 1
|
||||
for idx, word in groupby(token_to_word):
|
||||
score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
|
||||
labels_idx = score.argmax(axis=-1)
|
||||
|
||||
# Sum log prob over token, then normalize across labels
|
||||
score = np.exp(entities[token_start]) / np.exp(entities[token_start]).sum(-1, keepdims=True)
|
||||
label_idx = score.argmax()
|
||||
|
||||
if label_idx > 0:
|
||||
answer = []
|
||||
for idx, label_idx in enumerate(labels_idx):
|
||||
if self.model.config.id2label[label_idx] not in self.ignore_labels:
|
||||
answer += [{
|
||||
'word': words[idx],
|
||||
'score': score[label_idx].item(),
|
||||
'word': self.tokenizer.decode([int(input_ids[idx])]),
|
||||
'score': score[idx][label_idx].item(),
|
||||
'entity': self.model.config.id2label[label_idx]
|
||||
}]
|
||||
|
||||
# Update token start
|
||||
token_start += len(list(word))
|
||||
|
||||
# Append
|
||||
answers += [answer]
|
||||
if len(answers) == 1:
|
||||
return answers[0]
|
||||
return answers
|
||||
|
||||
|
||||
|
@ -22,6 +22,7 @@ from shutil import copyfile
|
||||
|
||||
import sentencepiece as spm
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_xlnet import SPIECE_UNDERLINE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -145,6 +146,11 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
||||
return self.fairseq_ids_to_tokens[index]
|
||||
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
||||
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
|
||||
to a directory.
|
||||
|
@ -22,6 +22,7 @@ from shutil import copyfile
|
||||
|
||||
import sentencepiece as spm
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_xlnet import SPIECE_UNDERLINE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -161,6 +162,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
||||
return self.fairseq_ids_to_tokens[index]
|
||||
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
||||
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
|
||||
to a directory.
|
||||
|
Loading…
Reference in New Issue
Block a user