further cleanup

This commit is contained in:
thomwolf 2019-12-18 11:50:54 +01:00
parent 8e5587fb79
commit 3d2096f516
6 changed files with 58 additions and 76 deletions

View File

@ -91,7 +91,7 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text):
def prepare_xlm_input(args, model, tokenizer, prompt_text):
kwargs = {"language": None, "mask_token": None}
kwargs = {"language": None, "mask_token_id": None}
# Set the language
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
@ -112,7 +112,7 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
# XLM masked-language modeling (MLM) models need masked token
is_xlm_mlm = "mlm" in args.model_name_or_path
if is_xlm_mlm:
kwargs["mask_token"] = tokenizer.mask_token_id
kwargs["mask_token_id"] = tokenizer.mask_token_id
return prompt_text, kwargs
@ -204,14 +204,13 @@ def main():
prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)
output_sequences = model.decode(
prompt_ids=encoded_prompt,
output_sequences = model.generate(
intput_ids=encoded_prompt,
length=args.length,
temperature=args.temperature,
k=args.k,
p=args.p,
top_k=args.k,
top_p=args.p,
repetition_penalty=args.repetition_penalty,
device=args.device,
**model_kwargs,
)

View File

@ -113,6 +113,8 @@ class XLMConfig(PretrainedConfig):
summary_first_dropout=0.1,
start_n_top=5,
end_n_top=5,
mask_token_id = 0,
lang_id = 0,
**kwargs):
"""Constructs XLMConfig.
"""
@ -156,6 +158,8 @@ class XLMConfig(PretrainedConfig):
self.summary_first_dropout = summary_first_dropout
self.start_n_top = start_n_top
self.end_n_top = end_n_top
self.mask_token_id = mask_token_id
self.lang_id = lang_id
else:
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")

View File

@ -488,7 +488,7 @@ class PreTrainedModel(nn.Module):
def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
length_penalty=None, num_return_sequences=None, **kwargs):
length_penalty=None, num_return_sequences=None, **model_kwargs):
""" Sequence generator for models with a LM head.
The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
@ -575,11 +575,13 @@ class PreTrainedModel(nn.Module):
output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, effective_batch_size,
length_penalty, num_beams, vocab_size)
length_penalty, num_beams, vocab_size,
**model_kwargs)
else:
output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, effective_batch_size)
pad_token_id, eos_token_ids, effective_batch_size,
**model_kwargs)
if num_return_sequences != 1:
output = output.view(batch_size, num_return_sequences, -1)
@ -587,7 +589,8 @@ class PreTrainedModel(nn.Module):
def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size):
pad_token_id, eos_token_ids, batch_size,
**model_kwargs):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequence are generated independantly.
"""
@ -598,7 +601,7 @@ class PreTrainedModel(nn.Module):
pasts = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
@ -640,7 +643,8 @@ class PreTrainedModel(nn.Module):
def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size,
length_penalty, num_beams, vocab_size):
length_penalty, num_beams, vocab_size,
**model_kwargs):
""" Generate sequences for each example with beam search.
"""
# Expand input to num beams
@ -662,7 +666,7 @@ class PreTrainedModel(nn.Module):
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)

View File

@ -639,6 +639,18 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def get_output_embeddings(self):
return self.pred_layer.proj
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
mask_token_id = model_kwargs['mask_token_id'] if 'mask_token_id' in model_kwargs else self.config.mask_token_id
lang_id = model_kwargs['lang_id'] if 'lang_id' in model_kwargs else self.config.lang_id
mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
input_ids = torch.cat([input_ids, mask_token], dim=1)
if lang_id is not None:
langs = torch.full_like(input_ids, lang_id)
else:
langs = None
return {"input_ids": input_ids, "langs": langs}
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
@ -657,33 +669,6 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
return outputs
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
mask_token = model_kwargs.pop("mask_token", None)
language = model_kwargs.pop("language", None)
input_ids = self._append_mask_token(input_ids, mask_token)
langs = self._create_language_embeddings(input_ids, language)
arguments = {"input_ids": input_ids, "langs": langs}
arguments.update(model_kwargs)
return arguments
@staticmethod
def _append_mask_token(sequence, mask_token_id):
""" Append a [MASK] token at the end of the sequence that the MLM model
is going to try to predict.
"""
if mask_token_id is not None:
tokens_to_append = torch.full((1, 1), mask_token_id, dtype=torch.long)
return torch.cat((sequence, tokens_to_append), dim=1)
return sequence
@staticmethod
def _create_language_embeddings(sequence, language):
if language is not None:
return torch.tensor([language] * sequence.shape[1]).view(1, -1)
return None
@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,

View File

@ -947,6 +947,30 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_loss
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
# Add dummy token at the end (no attention on this one)
dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device)
input_ids = torch.cat([input_ids, dummy_token], dim=1)
# Build permutation mask so that previous tokens don't see last token
perm_mask = torch.zeros(
(input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]),
dtype=torch.float, device=input_ids.device
)
perm_mask[:, :, -1] = 1.0
# We'll only predict the last token
target_mapping = torch.zeros(
(input_ids.shape[0], 1, input_ids.shape[1]),
dtype=torch.float, device=input_ids.device
)
target_mapping[0, 0, -1] = 1.0
return {"input_ids": input_ids,
"perm_mask": perm_mask,
"target_mapping": target_mapping
}
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
@ -972,40 +996,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
input_ids = self._add_dummy_token(input_ids)
perm_mask = self._create_perm_mask(input_ids)
target_mapping = self._create_target_mapping(input_ids)
arguments = {
"input_ids": input_ids,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
}
return arguments
@staticmethod
def _add_dummy_token(sequence):
dummy = torch.zeros((sequence.size(0), 1), dtype=torch.long)
return torch.cat((sequence, dummy), dim=1)
@staticmethod
def _create_perm_mask(sequence):
mask = torch.zeros(
(sequence.shape[0], sequence.shape[1], sequence.shape[1]),
dtype=torch.float,
)
mask[:, :, -1] = 1.0 # Previous tokens don't see last token
return mask
@staticmethod
def _create_target_mapping(sequence):
target_mapping = torch.zeros(
(sequence.shape[0], 1, sequence.shape[1]),
dtype=torch.float,
)
target_mapping[0, 0, -1] = 1.0 # predict last token
return target_mapping
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,