fill_mask helper

This commit is contained in:
Julien Chaumond 2020-01-18 00:29:46 -05:00
parent 83446a88d9
commit cacc17b884
2 changed files with 22 additions and 1 deletions

View File

@ -130,7 +130,7 @@ if is_sklearn_available():
# Modeling
if is_torch_available():
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, fill_mask
from .modeling_auto import (
AutoModel,
AutoModelForPreTraining,

View File

@ -18,6 +18,7 @@
import logging
import os
from typing import List, Tuple
import torch
from torch import nn
@ -34,6 +35,7 @@ from .file_utils import (
hf_bucket_url,
is_remote_url,
)
from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
@ -1504,3 +1506,22 @@ def prune_layer(layer, index, dim=None):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
def fill_mask(
    masked_input: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, topk=5
) -> List[Tuple[str, float]]:
    """
    Predict the single masked token in ``masked_input`` and return the ``topk``
    most probable filled sequences with their probabilities.

    Args:
        masked_input: Text containing exactly one mask token
            (``tokenizer.mask_token``).
        model: Masked-LM model whose forward pass returns prediction logits
            as its first output.
        tokenizer: Tokenizer matching ``model``; provides ``mask_token_id``.
        topk: Number of candidate fillings to return.

    Returns:
        A list of ``(decoded_sequence, probability)`` tuples, ordered from
        most to least probable.
    """
    tokens = tokenizer.encode(masked_input, add_special_tokens=True)
    input_ids = torch.tensor([tokens])
    # .item() assumes exactly one mask token is present; it raises if the
    # input contains zero or several masks.
    masked_index = (input_ids == tokenizer.mask_token_id).squeeze().nonzero().item()
    logits = model(input_ids)[0][0, masked_index, :]
    probs = logits.softmax(dim=0)
    values, predictions = probs.topk(topk)
    results = []
    for v, p in zip(values.tolist(), predictions.tolist()):
        tokens[masked_index] = p
        # Bug fix: append one (sequence, prob) tuple per prediction.
        # The previous `results += (decoded, v)` extended the list with the
        # tuple's elements, yielding a flat [str, float, str, float, ...]
        # list instead of the declared List[Tuple[str, float]].
        results.append((tokenizer.decode(tokens), v))
    return results