mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Add demo_camembert.py
This commit is contained in:
parent
14b3aa3b3c
commit
6e72fd094c
59
examples/demo_camembert.py
Normal file
59
examples/demo_camembert.py
Normal file
@ -0,0 +1,59 @@
|
||||
from pathlib import Path
|
||||
import tarfile
|
||||
import urllib.request
|
||||
|
||||
import torch
|
||||
|
||||
from transformers.tokenization_camembert import CamembertTokenizer
|
||||
from transformers.modeling_roberta import RobertaForMaskedLM
|
||||
|
||||
|
||||
def fill_mask(masked_input, model, tokenizer, topk=5):
    """Return the ``topk`` most probable fillings for the single mask in ``masked_input``.

    Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py

    Args:
        masked_input: Text containing exactly one occurrence of ``tokenizer.mask_token``.
        model: Callable taking a ``(1, seq_len)`` tensor of token ids and returning a
            tuple whose first element is indexed here as ``(batch, position, vocab)`` scores.
        tokenizer: Object providing ``encode``, ``convert_ids_to_tokens``,
            ``mask_token`` and ``mask_token_id``.
        topk: Number of candidate fillings to return.

    Returns:
        A list of ``(filled_text, probability, predicted_token)`` tuples, ordered from
        the most to the least probable candidate.

    Raises:
        AssertionError: If ``masked_input`` does not contain exactly one mask token.
    """
    masked_token = tokenizer.mask_token
    # Validate against the tokenizer's own mask token, which is also the one replaced below.
    assert masked_input.count(masked_token) == 1, (
        'masked_input must contain exactly one {0} token'.format(masked_token)
    )

    input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0)  # Batch size 1
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids)[0]  # Scores over the vocabulary for every position

    # `.item()` also guarantees that exactly one position holds the mask id.
    masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    prob = logits[0, masked_index, :].softmax(dim=0)
    values, indices = prob.topk(k=topk, dim=0)

    # When the mask is preceded by a space, that space is replaced together with the
    # mask: word-initial pieces bring their own leading space (converted from U+2581),
    # while continuation pieces get glued to the previous word.
    spaced_mask = ' {0}'.format(masked_token)
    target = spaced_mask if spaced_mask in masked_input else masked_token

    topk_filled_outputs = []
    for probability, token_id in zip(values.tolist(), indices.tolist()):
        predicted_token = tokenizer.convert_ids_to_tokens(token_id).replace('\u2581', ' ')
        topk_filled_outputs.append((
            masked_input.replace(target, predicted_token),
            probability,
            predicted_token,
        ))
    return topk_filled_outputs
|
||||
|
||||
|
||||
# Download and unpack the CamemBERT checkpoint on first run; later runs reuse
# the already-extracted directory.
model_path = Path('camembert.v0.pytorch')
if not model_path.exists():
    # Append the extension instead of using `with_suffix`, which would replace
    # the trailing ".pytorch" and save the archive as "camembert.v0.tar.gz".
    compressed_path = model_path.with_name(model_path.name + '.tar.gz')
    url = 'http://dl.fbaipublicfiles.com/camembert/camembert.v0.pytorch.tar.gz'
    print('Downloading model...')
    urllib.request.urlretrieve(url, compressed_path)
    print('Extracting model...')
    # NOTE(review): `extractall` trusts the member paths stored in the archive,
    # and the archive is fetched over plain HTTP - only use with a trusted source.
    with tarfile.open(compressed_path) as f:
        f.extractall(model_path.parent)
# The archive is expected to unpack into a directory named like `model_path`.
assert model_path.exists()

# The SentencePiece model is read from inside the extracted checkpoint directory.
tokenizer_path = model_path / 'sentencepiece.bpe.model'
tokenizer = CamembertTokenizer.from_pretrained(tokenizer_path)
model = RobertaForMaskedLM.from_pretrained(model_path)
model.eval()  # Switch the model to evaluation mode before predicting

masked_input = "Le camembert est <mask> :)"
print(fill_mask(masked_input, model, tokenizer, topk=3))
|
Loading…
Reference in New Issue
Block a user