further cleanup

thomwolf 2019-12-18 11:50:54 +01:00
parent 8e5587fb79
commit 3d2096f516
6 changed files with 58 additions and 76 deletions

View File

@@ -91,7 +91,7 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text):
def prepare_xlm_input(args, model, tokenizer, prompt_text):
kwargs = {"language": None, "mask_token": None}
kwargs = {"language": None, "mask_token_id": None}
# Set the language
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
@@ -112,7 +112,7 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
# XLM masked-language modeling (MLM) models need a mask token
is_xlm_mlm = "mlm" in args.model_name_or_path
if is_xlm_mlm:
kwargs["mask_token"] = tokenizer.mask_token_id
kwargs["mask_token_id"] = tokenizer.mask_token_id
return prompt_text, kwargs
@@ -204,14 +204,13 @@ def main():
prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)
output_sequences = model.decode(
prompt_ids=encoded_prompt,
output_sequences = model.generate(
input_ids=encoded_prompt,
length=args.length,
temperature=args.temperature,
k=args.k,
p=args.p,
top_k=args.k,
top_p=args.p,
repetition_penalty=args.repetition_penalty,
device=args.device,
**model_kwargs,
)
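The script now routes decoding through the shared PreTrainedModel.generate API instead of a model-specific decode. A minimal usage sketch of that call path, assuming a GPT-2 checkpoint and using max_length in place of the script's length argument (an assumption; the generate signature further down only lists max_length):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Sketch only: mirrors the renamed generate() arguments (top_k, top_p, repetition_penalty).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

encoded_prompt = torch.tensor(
    tokenizer.encode("The quick brown fox", add_special_tokens=False)
).unsqueeze(0)

output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=40,            # assumption: cap on total sequence length
    do_sample=True,
    temperature=0.7,
    top_k=0,
    top_p=0.9,
    repetition_penalty=1.0,
)
print(tokenizer.decode(output_sequences[0].tolist()))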

View File

@@ -113,6 +113,8 @@ class XLMConfig(PretrainedConfig):
summary_first_dropout=0.1,
start_n_top=5,
end_n_top=5,
mask_token_id=0,
lang_id=0,
**kwargs):
"""Constructs XLMConfig.
"""
@@ -156,6 +158,8 @@ class XLMConfig(PretrainedConfig):
self.summary_first_dropout = summary_first_dropout
self.start_n_top = start_n_top
self.end_n_top = end_n_top
self.mask_token_id = mask_token_id
self.lang_id = lang_id
else:
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")

View File

@@ -488,7 +488,7 @@ class PreTrainedModel(nn.Module):
def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
length_penalty=None, num_return_sequences=None, **kwargs):
length_penalty=None, num_return_sequences=None, **model_kwargs):
""" Sequence generator for models with a LM head.
The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
@@ -575,11 +575,13 @@ class PreTrainedModel(nn.Module):
output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, effective_batch_size,
length_penalty, num_beams, vocab_size)
length_penalty, num_beams, vocab_size,
**model_kwargs)
else:
output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, effective_batch_size)
pad_token_id, eos_token_ids, effective_batch_size,
**model_kwargs)
if num_return_sequences != 1:
output = output.view(batch_size, num_return_sequences, -1)
@@ -587,7 +589,8 @@ class PreTrainedModel(nn.Module):
def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size):
pad_token_id, eos_token_ids, batch_size,
**model_kwargs):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
"""
@@ -598,7 +601,7 @@ class PreTrainedModel(nn.Module):
pasts = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
@@ -640,7 +643,8 @@ class PreTrainedModel(nn.Module):
def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size,
length_penalty, num_beams, vocab_size):
length_penalty, num_beams, vocab_size,
**model_kwargs):
""" Generate sequences for each example with beam search.
"""
# Expand input to num beams
@@ -662,7 +666,7 @@ class PreTrainedModel(nn.Module):
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)
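This change threads **model_kwargs from generate() through both search loops into each model's prepare_inputs_for_generation, so model-specific inputs (for example XLM's lang_id) travel alongside the decoding parameters. A toy, framework-free sketch of that plumbing; the class and names are illustrative, not part of the library:

# Toy sketch of the kwargs pass-through; ToyGenerator is hypothetical.
class ToyGenerator:
    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
        # Whatever the caller handed to generate() that the loop did not consume arrives here.
        return {"input_ids": input_ids, **model_kwargs}

    def generate(self, input_ids, max_length=20, **model_kwargs):
        # Decoding parameters (max_length, ...) are used by the loop; the rest is forwarded as-is.
        return self.prepare_inputs_for_generation(input_ids, **model_kwargs)

print(ToyGenerator().generate([1, 2, 3], lang_id=0, mask_token_id=5))
# -> {'input_ids': [1, 2, 3], 'lang_id': 0, 'mask_token_id': 5}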

View File

@@ -639,6 +639,18 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def get_output_embeddings(self):
return self.pred_layer.proj
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
mask_token_id = model_kwargs['mask_token_id'] if 'mask_token_id' in model_kwargs else self.config.mask_token_id
lang_id = model_kwargs['lang_id'] if 'lang_id' in model_kwargs else self.config.lang_id
mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
input_ids = torch.cat([input_ids, mask_token], dim=1)
if lang_id is not None:
langs = torch.full_like(input_ids, lang_id)
else:
langs = None
return {"input_ids": input_ids, "langs": langs}
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
@@ -657,33 +669,6 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
return outputs
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
mask_token = model_kwargs.pop("mask_token", None)
language = model_kwargs.pop("language", None)
input_ids = self._append_mask_token(input_ids, mask_token)
langs = self._create_language_embeddings(input_ids, language)
arguments = {"input_ids": input_ids, "langs": langs}
arguments.update(model_kwargs)
return arguments
@staticmethod
def _append_mask_token(sequence, mask_token_id):
""" Append a [MASK] token at the end of the sequence that the MLM model
is going to try to predict.
"""
if mask_token_id is not None:
tokens_to_append = torch.full((1, 1), mask_token_id, dtype=torch.long)
return torch.cat((sequence, tokens_to_append), dim=1)
return sequence
@staticmethod
def _create_language_embeddings(sequence, language):
if language is not None:
return torch.tensor([language] * sequence.shape[1]).view(1, -1)
return None
@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
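In this file, the old _prepare_inputs_for_decoding helper and its two static methods are folded into the single prepare_inputs_for_generation above, which appends the mask token and builds a langs tensor. A torch-only sketch of that tensor manipulation, with made-up ids:

import torch

# Made-up ids: mask_token_id=5, lang_id=0; the prompt is three tokens.
input_ids = torch.tensor([[17, 42, 9]])
mask_token_id, lang_id = 5, 0

mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
input_ids = torch.cat([input_ids, mask_token], dim=1)  # tensor([[17, 42,  9,  5]])
langs = torch.full_like(input_ids, lang_id)            # tensor([[0, 0, 0, 0]])
print(input_ids)
print(langs)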

View File

@@ -947,6 +947,30 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_loss
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
# Add dummy token at the end (no attention on this one)
dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device)
input_ids = torch.cat([input_ids, dummy_token], dim=1)
# Build permutation mask so that previous tokens don't see last token
perm_mask = torch.zeros(
(input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]),
dtype=torch.float, device=input_ids.device
)
perm_mask[:, :, -1] = 1.0
# We'll only predict the last token
target_mapping = torch.zeros(
(input_ids.shape[0], 1, input_ids.shape[1]),
dtype=torch.float, device=input_ids.device
)
target_mapping[0, 0, -1] = 1.0
return {"input_ids": input_ids,
"perm_mask": perm_mask,
"target_mapping": target_mapping
}
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
@@ -972,40 +996,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
input_ids = self._add_dummy_token(input_ids)
perm_mask = self._create_perm_mask(input_ids)
target_mapping = self._create_target_mapping(input_ids)
arguments = {
"input_ids": input_ids,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
}
return arguments
@staticmethod
def _add_dummy_token(sequence):
dummy = torch.zeros((sequence.size(0), 1), dtype=torch.long)
return torch.cat((sequence, dummy), dim=1)
@staticmethod
def _create_perm_mask(sequence):
mask = torch.zeros(
(sequence.shape[0], sequence.shape[1], sequence.shape[1]),
dtype=torch.float,
)
mask[:, :, -1] = 1.0 # Previous tokens don't see last token
return mask
@staticmethod
def _create_target_mapping(sequence):
target_mapping = torch.zeros(
(sequence.shape[0], 1, sequence.shape[1]),
dtype=torch.float,
)
target_mapping[0, 0, -1] = 1.0 # predict last token
return target_mapping
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
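As with XLM, the XLNet-specific helpers are collapsed into prepare_inputs_for_generation above: a dummy token is appended, the permutation mask hides that last position from every token, and the target mapping selects only the last position for prediction. A torch-only sketch of the resulting shapes for a three-token prompt:

import torch

input_ids = torch.tensor([[14, 27, 8]])
dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device)
input_ids = torch.cat([input_ids, dummy_token], dim=1)   # shape (1, 4)

seq_len = input_ids.shape[1]
perm_mask = torch.zeros((1, seq_len, seq_len), dtype=torch.float)
perm_mask[:, :, -1] = 1.0          # no token may attend to the appended dummy
target_mapping = torch.zeros((1, 1, seq_len), dtype=torch.float)
target_mapping[0, 0, -1] = 1.0     # predict only the last (dummy) position

print(input_ids.shape, perm_mask.shape, target_mapping.shape)
# torch.Size([1, 4]) torch.Size([1, 4, 4]) torch.Size([1, 1, 4])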

View File

@@ -761,7 +761,7 @@ class PreTrainedTokenizer(object):
padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
The tokenizer padding sides are handled by the following strings:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
- 'right': pads on the right of the sequences
Defaults to False: no padding.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
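For reference, the two padding sides described in the docstring amount to the following on plain id lists (no tokenizer API assumed; pad_id and max_len are made up):

ids, pad_id, max_len = [101, 7592, 102], 0, 6
left_padded = [pad_id] * (max_len - len(ids)) + ids    # 'left'  -> [0, 0, 0, 101, 7592, 102]
right_padded = ids + [pad_id] * (max_len - len(ids))   # 'right' -> [101, 7592, 102, 0, 0, 0]
print(left_padded, right_padded)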