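"""Masked-word prediction with CamemBERT.

Given a French sentence containing a single <mask> token, score the top-k
candidate tokens with CamembertForMaskedLM and return the filled-in sentences.
"""
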
import torch

from transformers import CamembertForMaskedLM, CamembertTokenizer

def fill_mask(masked_input, model, tokenizer, topk=5):
    # Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
    assert masked_input.count("<mask>") == 1, "masked_input must contain exactly one <mask> token"
    input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0)  # Batch size 1
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(input_ids)[0]  # the prediction scores are the first element of the output tuple
    masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
    logits = logits[0, masked_index, :]
    prob = logits.softmax(dim=0)
    values, indices = prob.topk(k=topk, dim=0)
    topk_predicted_token_bpe = " ".join(
        [tokenizer.convert_ids_to_tokens(indices[i].item()) for i in range(len(indices))]
    )
    masked_token = tokenizer.mask_token
    topk_filled_outputs = []
    for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(" ")):
        # SentencePiece marks word boundaries with "\u2581"; map it back to a space
        predicted_token = predicted_token_bpe.replace("\u2581", " ")
        # Avoid a doubled space when the mask was preceded by one in the input
        if " {0}".format(masked_token) in masked_input:
            topk_filled_outputs.append(
                (
                    masked_input.replace(" {0}".format(masked_token), predicted_token),
                    values[index].item(),
                    predicted_token,
                )
            )
        else:
            topk_filled_outputs.append(
                (
                    masked_input.replace(masked_token, predicted_token),
                    values[index].item(),
                    predicted_token,
                )
            )
    # Returns a list of topk tuples: (filled sentence, probability, predicted token)
    return topk_filled_outputs

tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
model = CamembertForMaskedLM.from_pretrained("camembert-base")
model.eval()  # disable dropout for deterministic inference

masked_input = "Le camembert est <mask> :)"
print(fill_mask(masked_input, model, tokenizer, topk=3))
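
# For comparison, a minimal sketch of the same lookup via the fill-mask
# pipeline. This assumes a transformers version that ships `pipeline` and
# accepts the `top_k` argument; it is not part of the original example.
from transformers import pipeline

camembert_fill_mask = pipeline("fill-mask", model="camembert-base")
print(camembert_fill_mask(masked_input, top_k=3))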