mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
further cleanup
This commit is contained in:
parent
8e5587fb79
commit
3d2096f516
@ -91,7 +91,7 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text):
|
||||
|
||||
|
||||
def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
||||
kwargs = {"language": None, "mask_token": None}
|
||||
kwargs = {"language": None, "mask_token_id": None}
|
||||
|
||||
# Set the language
|
||||
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
|
||||
@ -112,7 +112,7 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
||||
# XLM masked-language modeling (MLM) models need masked token
|
||||
is_xlm_mlm = "mlm" in args.model_name_or_path
|
||||
if is_xlm_mlm:
|
||||
kwargs["mask_token"] = tokenizer.mask_token_id
|
||||
kwargs["mask_token_id"] = tokenizer.mask_token_id
|
||||
|
||||
return prompt_text, kwargs
|
||||
|
||||
@ -204,14 +204,13 @@ def main():
|
||||
prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
|
||||
encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)
|
||||
|
||||
output_sequences = model.decode(
|
||||
prompt_ids=encoded_prompt,
|
||||
output_sequences = model.generate(
|
||||
intput_ids=encoded_prompt,
|
||||
length=args.length,
|
||||
temperature=args.temperature,
|
||||
k=args.k,
|
||||
p=args.p,
|
||||
top_k=args.k,
|
||||
top_p=args.p,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
device=args.device,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
|
@ -113,6 +113,8 @@ class XLMConfig(PretrainedConfig):
|
||||
summary_first_dropout=0.1,
|
||||
start_n_top=5,
|
||||
end_n_top=5,
|
||||
mask_token_id = 0,
|
||||
lang_id = 0,
|
||||
**kwargs):
|
||||
"""Constructs XLMConfig.
|
||||
"""
|
||||
@ -156,6 +158,8 @@ class XLMConfig(PretrainedConfig):
|
||||
self.summary_first_dropout = summary_first_dropout
|
||||
self.start_n_top = start_n_top
|
||||
self.end_n_top = end_n_top
|
||||
self.mask_token_id = mask_token_id
|
||||
self.lang_id = lang_id
|
||||
else:
|
||||
raise ValueError("First argument must be either a vocabulary size (int)"
|
||||
" or the path to a pretrained model config file (str)")
|
||||
|
@ -488,7 +488,7 @@ class PreTrainedModel(nn.Module):
|
||||
def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
|
||||
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
|
||||
bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
|
||||
length_penalty=None, num_return_sequences=None, **kwargs):
|
||||
length_penalty=None, num_return_sequences=None, **model_kwargs):
|
||||
""" Sequence generator for models with a LM head.
|
||||
|
||||
The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||
@ -575,11 +575,13 @@ class PreTrainedModel(nn.Module):
|
||||
output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
|
||||
temperature, top_k, top_p, repetition_penalty,
|
||||
pad_token_id, eos_token_ids, effective_batch_size,
|
||||
length_penalty, num_beams, vocab_size)
|
||||
length_penalty, num_beams, vocab_size,
|
||||
**model_kwargs)
|
||||
else:
|
||||
output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
|
||||
temperature, top_k, top_p, repetition_penalty,
|
||||
pad_token_id, eos_token_ids, effective_batch_size)
|
||||
pad_token_id, eos_token_ids, effective_batch_size,
|
||||
**model_kwargs)
|
||||
|
||||
if num_return_sequences != 1:
|
||||
output = output.view(batch_size, num_return_sequences, -1)
|
||||
@ -587,7 +589,8 @@ class PreTrainedModel(nn.Module):
|
||||
|
||||
def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
|
||||
temperature, top_k, top_p, repetition_penalty,
|
||||
pad_token_id, eos_token_ids, batch_size):
|
||||
pad_token_id, eos_token_ids, batch_size,
|
||||
**model_kwargs):
|
||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequence are generated independantly.
|
||||
"""
|
||||
@ -598,7 +601,7 @@ class PreTrainedModel(nn.Module):
|
||||
pasts = None
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
|
||||
@ -640,7 +643,8 @@ class PreTrainedModel(nn.Module):
|
||||
def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
|
||||
temperature, top_k, top_p, repetition_penalty,
|
||||
pad_token_id, eos_token_ids, batch_size,
|
||||
length_penalty, num_beams, vocab_size):
|
||||
length_penalty, num_beams, vocab_size,
|
||||
**model_kwargs):
|
||||
""" Generate sequences for each example with beam search.
|
||||
"""
|
||||
# Expand input to num beams
|
||||
@ -662,7 +666,7 @@ class PreTrainedModel(nn.Module):
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
|
||||
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
|
||||
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
|
||||
|
@ -639,6 +639,18 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.pred_layer.proj
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
|
||||
mask_token_id = model_kwargs['mask_token_id'] if 'mask_token_id' in model_kwargs else self.config.mask_token_id
|
||||
lang_id = model_kwargs['lang_id'] if 'lang_id' in model_kwargs else self.config.lang_id
|
||||
|
||||
mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
|
||||
input_ids = torch.cat([input_ids, mask_token], dim=1)
|
||||
if lang_id is not None:
|
||||
langs = torch.full_like(input_ids, lang_id)
|
||||
else:
|
||||
langs = None
|
||||
return {"input_ids": input_ids, "langs": langs}
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
|
||||
lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
transformer_outputs = self.transformer(input_ids,
|
||||
@ -657,33 +669,6 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
|
||||
return outputs
|
||||
|
||||
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
|
||||
mask_token = model_kwargs.pop("mask_token", None)
|
||||
language = model_kwargs.pop("language", None)
|
||||
input_ids = self._append_mask_token(input_ids, mask_token)
|
||||
langs = self._create_language_embeddings(input_ids, language)
|
||||
arguments = {"input_ids": input_ids, "langs": langs}
|
||||
arguments.update(model_kwargs)
|
||||
|
||||
return arguments
|
||||
|
||||
@staticmethod
|
||||
def _append_mask_token(sequence, mask_token_id):
|
||||
""" Append a [MASK] token at the end of the sequence that the MLM model
|
||||
is going to try to predict.
|
||||
"""
|
||||
if mask_token_id is not None:
|
||||
tokens_to_append = torch.full((1, 1), mask_token_id, dtype=torch.long)
|
||||
return torch.cat((sequence, tokens_to_append), dim=1)
|
||||
|
||||
return sequence
|
||||
|
||||
@staticmethod
|
||||
def _create_language_embeddings(sequence, language):
|
||||
if language is not None:
|
||||
return torch.tensor([language] * sequence.shape[1]).view(1, -1)
|
||||
return None
|
||||
|
||||
|
||||
@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
|
@ -947,6 +947,30 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_loss
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
|
||||
# Add dummy token at the end (no attention on this one)
|
||||
dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device)
|
||||
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
||||
|
||||
# Build permutation mask so that previous tokens don't see last token
|
||||
perm_mask = torch.zeros(
|
||||
(input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]),
|
||||
dtype=torch.float, device=input_ids.device
|
||||
)
|
||||
perm_mask[:, :, -1] = 1.0
|
||||
|
||||
# We'll only predict the last token
|
||||
target_mapping = torch.zeros(
|
||||
(input_ids.shape[0], 1, input_ids.shape[1]),
|
||||
dtype=torch.float, device=input_ids.device
|
||||
)
|
||||
target_mapping[0, 0, -1] = 1.0
|
||||
|
||||
return {"input_ids": input_ids,
|
||||
"perm_mask": perm_mask,
|
||||
"target_mapping": target_mapping
|
||||
}
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
|
||||
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
transformer_outputs = self.transformer(input_ids,
|
||||
@ -972,40 +996,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
|
||||
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||
|
||||
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
|
||||
input_ids = self._add_dummy_token(input_ids)
|
||||
perm_mask = self._create_perm_mask(input_ids)
|
||||
target_mapping = self._create_target_mapping(input_ids)
|
||||
arguments = {
|
||||
"input_ids": input_ids,
|
||||
"perm_mask": perm_mask,
|
||||
"target_mapping": target_mapping,
|
||||
}
|
||||
return arguments
|
||||
|
||||
@staticmethod
|
||||
def _add_dummy_token(sequence):
|
||||
dummy = torch.zeros((sequence.size(0), 1), dtype=torch.long)
|
||||
return torch.cat((sequence, dummy), dim=1)
|
||||
|
||||
@staticmethod
|
||||
def _create_perm_mask(sequence):
|
||||
mask = torch.zeros(
|
||||
(sequence.shape[0], sequence.shape[1], sequence.shape[1]),
|
||||
dtype=torch.float,
|
||||
)
|
||||
mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
||||
return mask
|
||||
|
||||
@staticmethod
|
||||
def _create_target_mapping(sequence):
|
||||
target_mapping = torch.zeros(
|
||||
(sequence.shape[0], 1, sequence.shape[1]),
|
||||
dtype=torch.float,
|
||||
)
|
||||
target_mapping[0, 0, -1] = 1.0 # predict last token
|
||||
return target_mapping
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
|
@ -761,7 +761,7 @@ class PreTrainedTokenizer(object):
|
||||
padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
|
||||
The tokenizer padding sides are handled by the following strings:
|
||||
- 'left': pads on the left of the sequences
|
||||
- 'right': pads on the right of the sequences
|
||||
- 'right': pads on the right of the sequences
|
||||
Defaults to False: no padding.
|
||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||
or PyTorch torch.Tensor instead of a list of python integers.
|
||||
|
Loading…
Reference in New Issue
Block a user