Add demo_camembert.py

This commit is contained in:
Louis MARTIN 2019-11-08 17:09:48 -08:00 committed by Julien Chaumond
parent 14b3aa3b3c
commit 6e72fd094c

View File

@ -0,0 +1,59 @@
from pathlib import Path
import tarfile
import urllib.request
import torch
from transformers.tokenization_camembert import CamembertTokenizer
from transformers.modeling_roberta import RobertaForMaskedLM
def fill_mask(masked_input, model, tokenizer, topk=5):
# Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
assert masked_input.count('<mask>') == 1
input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0) # Batch size 1
logits = model(input_ids)[0] # The last hidden-state is the first element of the output tuple
masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
logits = logits[0, masked_index, :]
prob = logits.softmax(dim=0)
values, indices = prob.topk(k=topk, dim=0)
topk_predicted_token_bpe = ' '.join([tokenizer.convert_ids_to_tokens(indices[i].item())
for i in range(len(indices))])
masked_token = tokenizer.mask_token
topk_filled_outputs = []
for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(' ')):
predicted_token = predicted_token_bpe.replace('\u2581', ' ')
if " {0}".format(masked_token) in masked_input:
topk_filled_outputs.append((
masked_input.replace(
' {0}'.format(masked_token), predicted_token
),
values[index].item(),
predicted_token,
))
else:
topk_filled_outputs.append((
masked_input.replace(masked_token, predicted_token),
values[index].item(),
predicted_token,
))
return topk_filled_outputs
model_path = Path('camembert.v0.pytorch')
if not model_path.exists():
compressed_path = model_path.with_suffix('.tar.gz')
url = 'http://dl.fbaipublicfiles.com/camembert/camembert.v0.pytorch.tar.gz'
print('Downloading model...')
urllib.request.urlretrieve(url, compressed_path)
print('Extracting model...')
with tarfile.open(compressed_path) as f:
f.extractall(model_path.parent)
assert model_path.exists()
tokenizer_path = model_path / 'sentencepiece.bpe.model'
tokenizer = CamembertTokenizer.from_pretrained(tokenizer_path)
model = RobertaForMaskedLM.from_pretrained(model_path)
model.eval()
masked_input = "Le camembert est <mask> :)"
print(fill_mask(masked_input, model, tokenizer, topk=3))