mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
fill_mask helper
This commit is contained in:
parent
83446a88d9
commit
cacc17b884
@ -130,7 +130,7 @@ if is_sklearn_available():
|
||||
|
||||
# Modeling
|
||||
if is_torch_available():
|
||||
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D
|
||||
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, fill_mask
|
||||
from .modeling_auto import (
|
||||
AutoModel,
|
||||
AutoModelForPreTraining,
|
||||
|
@ -18,6 +18,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -34,6 +35,7 @@ from .file_utils import (
|
||||
hf_bucket_url,
|
||||
is_remote_url,
|
||||
)
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -1504,3 +1506,22 @@ def prune_layer(layer, index, dim=None):
|
||||
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
|
||||
else:
|
||||
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
|
||||
|
||||
|
||||
def fill_mask(
    masked_input: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, topk=5
) -> List[Tuple[str, float]]:
    """
    Predict a single masked token and return the most probable filled sequences.

    Args:
        masked_input: Text containing exactly one occurrence of the tokenizer's mask token.
        model: Model called as ``model(input_ids)``; the first element of its output is
            indexed as ``[batch, position, vocab]`` to obtain the scores at the mask position.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and ``mask_token_id``.
        topk: Number of candidate fillings to return.

    Returns:
        A list of ``(filled_sequence, probability)`` tuples, ordered from most to least
        probable, where ``filled_sequence`` is the decoded input with the mask replaced.

    Raises:
        ValueError: If the encoded input does not contain exactly one mask token.

    Note:
        The model is called as-is; this helper does not switch it to eval mode.
    """
    tokens = tokenizer.encode(masked_input, add_special_tokens=True)
    input_ids = torch.tensor([tokens])

    # Locate the mask explicitly so a missing or repeated mask produces a clear
    # error instead of an opaque tensor-to-scalar conversion failure.
    mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero().flatten().tolist()
    if len(mask_positions) != 1:
        raise ValueError(
            "Expected exactly one mask token in the input, found {}".format(len(mask_positions))
        )
    masked_index = mask_positions[0]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_ids)[0][0, masked_index, :]
    probs = logits.softmax(dim=0)
    values, predictions = probs.topk(topk)

    results = []
    for v, p in zip(values.tolist(), predictions.tolist()):
        filled = list(tokens)
        filled[masked_index] = p
        # Append one (sequence, probability) pair; `+=` with a tuple would
        # flatten the pair into two separate list items.
        results.append((tokenizer.decode(filled), v))
    return results
|
||||
|
Loading…
Reference in New Issue
Block a user