commit 3d2096f516
parent 8e5587fb79

    further cleanup
@@ -91,7 +91,7 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text):
 
 
 def prepare_xlm_input(args, model, tokenizer, prompt_text):
-    kwargs = {"language": None, "mask_token": None}
+    kwargs = {"language": None, "mask_token_id": None}
 
     # Set the language
     use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
@@ -112,7 +112,7 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
     # XLM masked-language modeling (MLM) models need masked token
     is_xlm_mlm = "mlm" in args.model_name_or_path
     if is_xlm_mlm:
-        kwargs["mask_token"] = tokenizer.mask_token_id
+        kwargs["mask_token_id"] = tokenizer.mask_token_id
 
     return prompt_text, kwargs
 
@@ -204,14 +204,13 @@ def main():
     prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
     encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)
 
-    output_sequences = model.decode(
-        prompt_ids=encoded_prompt,
+    output_sequences = model.generate(
+        input_ids=encoded_prompt,
         length=args.length,
         temperature=args.temperature,
-        k=args.k,
-        p=args.p,
+        top_k=args.k,
+        top_p=args.p,
         repetition_penalty=args.repetition_penalty,
-        device=args.device,
         **model_kwargs,
     )
 
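For context, a minimal sketch of the updated call path outside the script, assuming the PreTrainedModel.generate() signature shown further down in this diff; the checkpoint name, prompt and sampling values are illustrative placeholders, not part of the commit:

import torch
from transformers import XLMTokenizer, XLMWithLMHeadModel

# Placeholder checkpoint; any XLM MLM checkpoint can stand in here.
tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
model.eval()

prompt_text = "Today the weather is"
encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)

# Keyword arguments that generate() itself does not consume (here mask_token_id) are
# collected into **model_kwargs and handed to the model's prepare_inputs_for_generation.
output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=30,
    temperature=1.0,
    top_k=0,
    top_p=0.9,
    repetition_penalty=1.0,
    mask_token_id=tokenizer.mask_token_id,
)
print(tokenizer.decode(output_sequences[0].tolist()))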
@@ -113,6 +113,8 @@ class XLMConfig(PretrainedConfig):
                  summary_first_dropout=0.1,
                  start_n_top=5,
                  end_n_top=5,
+                 mask_token_id=0,
+                 lang_id=0,
                  **kwargs):
         """Constructs XLMConfig.
         """
@@ -156,6 +158,8 @@ class XLMConfig(PretrainedConfig):
             self.summary_first_dropout = summary_first_dropout
             self.start_n_top = start_n_top
             self.end_n_top = end_n_top
+            self.mask_token_id = mask_token_id
+            self.lang_id = lang_id
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              " or the path to a pretrained model config file (str)")
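The two new constructor arguments become attributes on the config, which XLMWithLMHeadModel.prepare_inputs_for_generation (added later in this commit) uses as fall-back values whenever the caller does not pass mask_token_id / lang_id to generate(). A small sketch; the vocabulary size and id values are placeholders:

from transformers import XLMConfig

# The first positional argument is the vocabulary size (see the ValueError branch above).
config = XLMConfig(30145, mask_token_id=5, lang_id=0)
print(config.mask_token_id, config.lang_id)  # -> 5 0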
@@ -488,7 +488,7 @@ class PreTrainedModel(nn.Module):
     def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
                  temperature=None, top_k=None, top_p=None, repetition_penalty=None,
                  bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
-                 length_penalty=None, num_return_sequences=None, **kwargs):
+                 length_penalty=None, num_return_sequences=None, **model_kwargs):
         """ Sequence generator for models with a LM head.
 
         The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
@@ -575,11 +575,13 @@ class PreTrainedModel(nn.Module):
             output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
                                                 temperature, top_k, top_p, repetition_penalty,
                                                 pad_token_id, eos_token_ids, effective_batch_size,
-                                                length_penalty, num_beams, vocab_size)
+                                                length_penalty, num_beams, vocab_size,
+                                                **model_kwargs)
         else:
             output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
                                                    temperature, top_k, top_p, repetition_penalty,
-                                                   pad_token_id, eos_token_ids, effective_batch_size)
+                                                   pad_token_id, eos_token_ids, effective_batch_size,
+                                                   **model_kwargs)
 
         if num_return_sequences != 1:
             output = output.view(batch_size, num_return_sequences, -1)
@@ -587,7 +589,8 @@ class PreTrainedModel(nn.Module):
 
     def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
                                  temperature, top_k, top_p, repetition_penalty,
-                                 pad_token_id, eos_token_ids, batch_size):
+                                 pad_token_id, eos_token_ids, batch_size,
+                                 **model_kwargs):
         """ Generate sequences for each example without beam search (num_beams == 1).
             All returned sequence are generated independantly.
         """
@@ -598,7 +601,7 @@ class PreTrainedModel(nn.Module):
         pasts = None
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
             outputs = self(**model_inputs)
             next_token_logits = outputs[0][:, -1, :]
 
@@ -640,7 +643,8 @@ class PreTrainedModel(nn.Module):
     def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
                               temperature, top_k, top_p, repetition_penalty,
                               pad_token_id, eos_token_ids, batch_size,
-                              length_penalty, num_beams, vocab_size):
+                              length_penalty, num_beams, vocab_size,
+                              **model_kwargs):
         """ Generate sequences for each example with beam search.
         """
         # Expand input to num beams
@@ -662,7 +666,7 @@ class PreTrainedModel(nn.Module):
         done = [False for _ in range(batch_size)]
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
             scores = self(**model_inputs)[0]  # (batch_size * num_beams, cur_len, vocab_size)
             scores = scores[:, -1, :]  # (batch_size * num_beams, vocab_size)
 
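The net effect of threading **model_kwargs through generate(), _generate_no_beam_search() and _generate_beam_search() is that any keyword argument generate() does not consume reaches the model's prepare_inputs_for_generation() untouched. A stripped-down sketch of that call chain (ToyModel and its methods are illustrative stand-ins, not library code):

import torch

class ToyModel:
    def prepare_inputs_for_generation(self, input_ids, pasts=None, **model_kwargs):
        # Model-specific kwargs (e.g. mask_token_id or lang_id for XLM) arrive here.
        return {"input_ids": input_ids, **model_kwargs}

    def generate(self, input_ids, max_length=20, **model_kwargs):
        # Stand-in for PreTrainedModel.generate -> _generate_no_beam_search.
        return self._generate_no_beam_search(input_ids, max_length, **model_kwargs)

    def _generate_no_beam_search(self, input_ids, max_length, **model_kwargs):
        return self.prepare_inputs_for_generation(input_ids, pasts=None, **model_kwargs)

inputs = ToyModel().generate(torch.tensor([[1, 2, 3]]), lang_id=0)
print(sorted(inputs))  # -> ['input_ids', 'lang_id']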
@@ -639,6 +639,18 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
     def get_output_embeddings(self):
         return self.pred_layer.proj
 
+    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
+        mask_token_id = model_kwargs['mask_token_id'] if 'mask_token_id' in model_kwargs else self.config.mask_token_id
+        lang_id = model_kwargs['lang_id'] if 'lang_id' in model_kwargs else self.config.lang_id
+
+        mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
+        input_ids = torch.cat([input_ids, mask_token], dim=1)
+        if lang_id is not None:
+            langs = torch.full_like(input_ids, lang_id)
+        else:
+            langs = None
+        return {"input_ids": input_ids, "langs": langs}
+
     def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                 lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
@@ -657,33 +669,6 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
 
         return outputs
 
-    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
-        mask_token = model_kwargs.pop("mask_token", None)
-        language = model_kwargs.pop("language", None)
-        input_ids = self._append_mask_token(input_ids, mask_token)
-        langs = self._create_language_embeddings(input_ids, language)
-        arguments = {"input_ids": input_ids, "langs": langs}
-        arguments.update(model_kwargs)
-
-        return arguments
-
-    @staticmethod
-    def _append_mask_token(sequence, mask_token_id):
-        """ Append a [MASK] token at the end of the sequence that the MLM model
-        is going to try to predict.
-        """
-        if mask_token_id is not None:
-            tokens_to_append = torch.full((1, 1), mask_token_id, dtype=torch.long)
-            return torch.cat((sequence, tokens_to_append), dim=1)
-
-        return sequence
-
-    @staticmethod
-    def _create_language_embeddings(sequence, language):
-        if language is not None:
-            return torch.tensor([language] * sequence.shape[1]).view(1, -1)
-        return None
-
 
 @add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
     the pooled output) e.g. for GLUE tasks. """,
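What the new XLM hook returns is easiest to see on a toy prompt; a minimal sketch that mirrors the added method with placeholder token ids (7, 8, 9 for the prompt, 5 for the mask token, 0 for the language id):

import torch

input_ids = torch.tensor([[7, 8, 9]])
mask_token_id, lang_id = 5, 0

# Same steps as XLMWithLMHeadModel.prepare_inputs_for_generation above.
mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
input_ids = torch.cat([input_ids, mask_token], dim=1)
langs = torch.full_like(input_ids, lang_id)

print(input_ids)  # tensor([[7, 8, 9, 5]]) -- the prompt with a mask token appended
print(langs)      # tensor([[0, 0, 0, 0]]) -- one language id per position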
@@ -947,6 +947,30 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_loss
 
+    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
+        # Add dummy token at the end (no attention on this one)
+        dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device)
+        input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+        # Build permutation mask so that previous tokens don't see last token
+        perm_mask = torch.zeros(
+            (input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]),
+            dtype=torch.float, device=input_ids.device
+        )
+        perm_mask[:, :, -1] = 1.0
+
+        # We'll only predict the last token
+        target_mapping = torch.zeros(
+            (input_ids.shape[0], 1, input_ids.shape[1]),
+            dtype=torch.float, device=input_ids.device
+        )
+        target_mapping[0, 0, -1] = 1.0
+
+        return {"input_ids": input_ids,
+                "perm_mask": perm_mask,
+                "target_mapping": target_mapping
+                }
+
     def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
                 token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
@@ -972,40 +996,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
 
         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
 
-    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
-        input_ids = self._add_dummy_token(input_ids)
-        perm_mask = self._create_perm_mask(input_ids)
-        target_mapping = self._create_target_mapping(input_ids)
-        arguments = {
-            "input_ids": input_ids,
-            "perm_mask": perm_mask,
-            "target_mapping": target_mapping,
-        }
-        return arguments
-
-    @staticmethod
-    def _add_dummy_token(sequence):
-        dummy = torch.zeros((sequence.size(0), 1), dtype=torch.long)
-        return torch.cat((sequence, dummy), dim=1)
-
-    @staticmethod
-    def _create_perm_mask(sequence):
-        mask = torch.zeros(
-            (sequence.shape[0], sequence.shape[1], sequence.shape[1]),
-            dtype=torch.float,
-        )
-        mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-        return mask
-
-    @staticmethod
-    def _create_target_mapping(sequence):
-        target_mapping = torch.zeros(
-            (sequence.shape[0], 1, sequence.shape[1]),
-            dtype=torch.float,
-        )
-        target_mapping[0, 0, -1] = 1.0  # predict last token
-        return target_mapping
-
 
 @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
     the pooled output) e.g. for GLUE tasks. """,
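The two masks built by the new XLNet hook are easiest to understand by their shapes; a sketch mirroring the added code on a 3-token placeholder prompt:

import torch

input_ids = torch.tensor([[11, 12, 13]])
dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device)
input_ids = torch.cat([input_ids, dummy_token], dim=1)           # shape (1, 4)

perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
perm_mask[:, :, -1] = 1.0       # no position may attend to the appended dummy token

target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)
target_mapping[0, 0, -1] = 1.0  # predict only the last (dummy) position

print(input_ids.shape, perm_mask.shape, target_mapping.shape)
# torch.Size([1, 4]) torch.Size([1, 4, 4]) torch.Size([1, 1, 4])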
@@ -761,7 +761,7 @@ class PreTrainedTokenizer(object):
                 padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
                 The tokenizer padding sides are handled by the following strings:
                 - 'left': pads on the left of the sequences
                 - 'right': pads on the right of the sequences
                 Defaults to False: no padding.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
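Padding side only changes where the pad ids go; a plain-Python sketch with placeholder ids (0 stands in for the pad token id):

ids, pad_id, max_length = [5, 6, 7], 0, 6
padded_right = ids + [pad_id] * (max_length - len(ids))  # 'right' -> [5, 6, 7, 0, 0, 0]
padded_left = [pad_id] * (max_length - len(ids)) + ids   # 'left'  -> [0, 0, 0, 5, 6, 7]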