refactoring generation

thomwolf 2019-12-16 22:22:30 +01:00
parent 07bc8efbc3
commit a468870fd2
2 changed files with 213 additions and 227 deletions


@@ -57,8 +57,19 @@ class PretrainedConfig(object):
self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
self.pruned_heads = kwargs.pop('pruned_heads', {})
# `is_decoder` is used in encoder-decoder models to differentiate encoder from decoder
self.is_decoder = kwargs.pop('is_decoder', False)
# Parameters for sequence generation
self.generate_length = kwargs.pop('generate_length', 10)
self.generate_do_sample = kwargs.pop('generate_do_sample', False)
self.generate_num_beams = kwargs.pop('generate_num_beams', 1)
self.generate_temperature = kwargs.pop('generate_temperature', 1.0)
self.generate_top_k = kwargs.pop('generate_top_k', 50)
self.generate_top_p = kwargs.pop('generate_top_p', 0.0)
self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0)
def save_pretrained(self, save_directory):
""" Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
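
As an aside, these new `generate_*` attributes are plain keyword arguments popped in `PretrainedConfig.__init__`, so they can be set at construction time and round-trip through serialization. A minimal sketch, assuming a config subclass such as `GPT2Config` that forwards unknown kwargs to `PretrainedConfig`:

import os
from transformers import GPT2Config

# The `generate_*` kwargs are popped by `PretrainedConfig.__init__` and
# stored as attributes, so generation defaults travel with the config file.
config = GPT2Config(generate_length=40, generate_do_sample=True,
                    generate_top_k=40, generate_top_p=0.95)
assert config.generate_length == 40

os.makedirs("./my-gpt2", exist_ok=True)
config.save_pretrained("./my-gpt2")  # serialized alongside the other fields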


@@ -82,6 +82,7 @@ class PreTrainedModel(nn.Module):
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
# Save config in model
self.config = config
@@ -89,93 +90,6 @@ class PreTrainedModel(nn.Module):
def base_model(self):
return getattr(self, self.base_model_prefix, self)
def decode(self,
prompt_ids=None,
device=torch.device('cpu'),
length=10,
do_sample=False,
temperature=1.,
k=9,
p=0,
repetition_penalty=1,
**model_kwargs):
""" Generic sequence generator for single-stack models with a LM head.
The method currently supports greedy decoding and sampling. See the
documentation of the `Sampler` class for more information about the
parameters related to sampling.
Params:
**prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape (1,)
**device**: (`optional`) `torch.device`
The device on which the prompt_ids will be initialized if not provided.
**length**: (`optional`) int
The length of the sequence to be generated.
**do_sample**: (`optional`) bool
If set to `False` we use greedy decoding; otherwise sampling.
**temperature**: (`optional`) float
The value used to modulate the next token probabilities.
**k**: (`optional`) int
The parameter used for top-k filtering.
**p**: (`optional`) float
The parameter for nucleus sampling. Must be between 0 and 1.
**repetition_penalty**: (`optional`) float
The parameter for repetition penalty.
"""
if prompt_ids is None:
prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)
# When the model does not have an LM head `get_output_embeddings`
# returns `None`. We use this mechanism to determine whether we
# should proceed with decoding or not.
if self.get_output_embeddings() is None:
raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.")
# The following checks that the model is on the same device as the one
# that is specified. It only works for models that fit on one GPU.
model_device = next(self.parameters()).device
if model_device != prompt_ids.device:
warnings.warn(
"The model is not on the same device as the prompts. Expected {}, got {}.".format(
prompt_ids.device, model_device
)
)
sampler_config = {
"k": k,
"p": p,
"do_sample": do_sample,
"temperature": temperature,
"repetition_penalty": repetition_penalty,
}
return self._greedy_decode_or_sample(prompt_ids, length, sampler_config, **model_kwargs)
def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **model_kwargs):
""" Generate text using greedy decoding or by sampling tokens."""
sampler = Sampler(**sampler_config)
generated_sequence = prompt_ids
with torch.no_grad():
for _ in trange(length):
arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
outputs = self(**arguments)
next_tokens_logits = outputs[0][:, -1, :]
next_tokens = sampler.get_one_token(
next_tokens_logits, generated_sequence
)
generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)
return generated_sequence.squeeze(0)
def _prepare_inputs_for_decoding(self, input_ids, **kwargs):
arguments = {"input_ids": input_ids}
arguments.update(kwargs)
return arguments
def get_input_embeddings(self):
""" Get model's input embeddings
"""
@@ -306,6 +220,9 @@ class PreTrainedModel(nn.Module):
# Tie weights if needed
self.tie_weights()
# Initialize decoding head if we have output embeddings
def prune_heads(self, heads_to_prune):
""" Prunes heads of the base model.
@@ -571,6 +488,204 @@ class PreTrainedModel(nn.Module):
return model
def generate(self, input_ids=None, length=None, do_sample=False, num_beams=None,
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
**model_kwargs):
""" Generic sequence generator for single-stack models with a LM head.
The method currently supports greedy decoding and sampling. See the
documentation of the `Sampler` class for more information about the
parameters related to sampling.
Params:
**input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape (1,)
**length**: (`optional`) int
The length of the sequence to be generated.
**do_sample**: (`optional`) bool
If set to `False` we use greedy decoding; otherwise sampling.
**temperature**: (`optional`) float
The value used to module the next token probabilities.
**k**: (`optional`) int
The parameter used for k-filtering.
**p**: (`optional`) float
The parameter for nucleus sampling. Must be between 0 and 1.
**repetition_penalty**: (`optional`) float
The parameter for repetition penalty.
"""
if input_ids is None:
    input_ids = torch.tensor([[]], dtype=torch.long, device=next(self.parameters()).device)
# We cannot generate if the model does not have an LM head
if self.get_output_embeddings() is None:
    raise AttributeError("You tried to generate sequences with a model that does not have an LM Head.")
# Fall back on the generation defaults stored in the model config
# when an argument is not provided
length = length if length is not None else self.config.generate_length
temperature = temperature if temperature is not None else self.config.generate_temperature
top_k = top_k if top_k is not None else self.config.generate_top_k
top_p = top_p if top_p is not None else self.config.generate_top_p
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.generate_repetition_penalty
sampler_config = {
    "k": top_k,
    "p": top_p,
    "do_sample": do_sample,
    "temperature": temperature,
    "repetition_penalty": repetition_penalty,
}
sampler = Sampler(**sampler_config)
generated_sequence = input_ids
with torch.no_grad():
    for _ in trange(length):
        arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
        outputs = self(**arguments)
        next_tokens_logits = outputs[0][:, -1, :]
        next_tokens = sampler.get_one_token(next_tokens_logits, generated_sequence)
        generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)
return generated_sequence.squeeze(0)
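
For illustration, a minimal usage sketch of the new `generate` API, assuming a GPT-2 checkpoint with an LM head (the class names follow the library's usual examples and are not part of this diff):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

# Encode a prompt and sample 20 new tokens with top-k and nucleus filtering.
# Arguments left unset fall back on the `generate_*` config defaults.
input_ids = torch.tensor([tokenizer.encode("The weather today is")])
output = model.generate(input_ids, length=20, do_sample=True,
                        temperature=0.9, top_k=40, top_p=0.9)
print(tokenizer.decode(output.tolist()))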
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
    # Build and return the dict explicitly; `dict.update` returns `None`
    arguments = {"input_ids": input_ids}
    arguments.update(model_kwargs)
    return arguments
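
`_prepare_inputs_for_decoding` is deliberately a thin hook: a subclass can override it to reshape the inputs fed to the model at each step. A hypothetical sketch of such an override (the `past` cache handling below is illustrative, not part of this commit):

from transformers import PreTrainedModel

class CachedLMHeadModel(PreTrainedModel):
    # Hypothetical override: with a key/value cache, only the newest
    # token needs to be fed to the model once `past` is available.
    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
        if model_kwargs.get("past") is not None:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, **model_kwargs}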
class Sampler(object):
r""" Sampler is used to generate sequences of ids from logit inputs.
Greedy decoding, which consists in chosing the most probable token at each
step, is the default behaviour. Sampling with varying temperature, top_k
and nucleus filtering is also implemented.
Attributes:
**device**: ``torch.device``
Device on which the computations will be run.
**do_sample**: bool
Whether to sample or do greedy decoding.
**k**: int between 0 and vocab_size
Parameter for the top-k filtering
**p**: float between 0 and 1
Parameter for the nucleus filtering
**temperature**: strictly positive float
Parameter used to modulate the distribution over ids. Low temperatures
put more emphasis on highly probably token while high temperatures tend
to smooth the probability distribution.
**repetition_penalty**: strictly postitive float
The penalty applied to repeating ids
"""
def __init__(
self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0
):
self.k = k
self.p = p
self.do_sample = do_sample
self.temperature = temperature
self.repetition_penalty = repetition_penalty
self.do_apply_repetition_penalty = repetition_penalty > 1
if self.p > 1:
warnings.warn(
"""You are trying to apply nucleus filtering with a value of p greater than 1 ({}).
However p is a probability and its value must lie between 0 and 1. In effect, no filtering
will be applied. If this is not the behavior you expect, change the value of p.""".format(
self.p
)
)
def get_one_token(self, next_token_logits, past_sequence):
logits = self.apply_repetition_penalty(next_token_logits, past_sequence)
if self.do_sample:
logits = self.apply_temperature(logits)
logits = self.apply_top_k_filter(logits)
logits = self.apply_nucleus_filter(logits)
return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
return torch.argmax(logits, dim=-1).unsqueeze(-1)
def apply_repetition_penalty(self, logits, past_sequence):
""" Apply a penalty to tokens that appear more than once in the
generated sequence.
.. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer
language model for controllable generation." arXiv preprint
arXiv:1909.05858 (2019).
"""
if self.do_apply_repetition_penalty:
generated_token_idx = set(past_sequence[0].tolist())
for token_idx in generated_token_idx:
logits[0, token_idx] /= self.repetition_penalty
return logits
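
A small worked example of the penalty's effect on a single step (numbers illustrative); note that dividing assumes the penalized logits are positive:

import torch

logits = torch.tensor([[4.0, 2.0, 1.0]])
past_sequence = torch.tensor([[0]])  # token 0 was generated earlier
repetition_penalty = 1.3

for token_idx in set(past_sequence[0].tolist()):
    logits[0, token_idx] /= repetition_penalty
# logits is now roughly [[3.08, 2.0, 1.0]]: token 0 is less likely to repeat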
def apply_temperature(self, logits):
""" Shape the tokens' distribution through temperature. The higher the value
of the temperature, the more skewed towards high probability events the
distribution is.
.. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning.
MIT press, 2016.
"""
# when dividing a float by 0, torch returns inf which in turn breaks the
# multinomial with an error message that is not very helpful. It is better
# to stop the execution here and explain why.
if self.temperature == 0:
raise ZeroDivisionError(
"""You are trying to sample with a temperature equal to 0.
If you wanted to do greedy decoding, set `do_sample` to False instead.
Otherwise set the temperature to a value different from 0."""
)
return logits / self.temperature
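
A quick numerical illustration of the effect (values rounded):

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5])
print(F.softmax(logits, dim=-1))        # ~[0.63, 0.23, 0.14]
print(F.softmax(logits / 0.5, dim=-1))  # ~[0.84, 0.11, 0.04] -> sharper
print(F.softmax(logits / 2.0, dim=-1))  # ~[0.48, 0.29, 0.23] -> flatter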
def apply_top_k_filter(self, logits):
""" Use the probability distribution of the tokens to determine the set
to be sampled from. Specifically we select the set of size k such that
the sum of its items' probabilities is maximum.
.. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural
story generation." arXiv preprint arXiv:1805.04833 (2018).
"""
if self.k > 0:
vocabulary_size = logits.size(-1)
if self.k > vocabulary_size:
warnings.warn(
"""You provided a value for k ({}) that is larger than the vocabulary size ({}).
We adjusted k's value to the vocabulary size; if that was what you intended to do
we recommend setting k to 0 instead. If this is not the behavior you expected,
choose a value of k that is smaller than the vocabulary size.""".format(
self.k, vocabulary_size
)
)
self.k = vocabulary_size
indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None]
logits[indices_to_remove] = -float("Inf")
return logits
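
Walking through the masking line on a toy tensor: `torch.topk` returns the k largest logits, its last column is the k-th largest value, and every logit strictly below it is pushed to minus infinity:

import torch

logits = torch.tensor([[1.0, 3.0, 0.5, 2.0]])
k = 2
kth_value = torch.topk(logits, k)[0][..., -1, None]  # tensor([[2.0]])
logits[logits < kth_value] = -float("Inf")
print(logits)  # tensor([[-inf, 3.0, -inf, 2.0]])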
def apply_nucleus_filter(self, logits):
""" Use the probability distribution of the tokens to determine the set
to be sampled from. Specifically, choose the smallest set such that the
sum of its items' probabilities is greater than a number p in [0,1].
.. Holtzman, Ari, et al. "The curious case of neural text
degeneration." arXiv preprint arXiv:1904.09751 (2019).
"""
if self.p > 0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
sorted_probabilities = F.softmax(sorted_logits, dim=-1)
cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)
# Remove tokens with cumulative probability above the threshold,
# but keep the first token above the threshold.
sorted_indices_to_remove = cumulative_probabilities > self.p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = -float("Inf")
return logits
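
The sort/cumsum/shift sequence above is easiest to see on a toy example. With p = 0.8 the cumulative probabilities cross the threshold at the second token, and the right-shift keeps that crossing token in the nucleus:

import torch
import torch.nn.functional as F

logits = torch.tensor([[3.0, 2.0, 1.0, 0.0]])
p = 0.8
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# cumulative ~ [0.64, 0.88, 0.97, 1.00]
to_remove = cumulative > p                        # [False, True, True, True]
to_remove[..., 1:] = to_remove[..., :-1].clone()  # [False, False, True, True]
to_remove[..., 0] = 0
# only the two least probable tokens are filtered out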
class Conv1D(nn.Module):
def __init__(self, nf, nx):
@@ -948,143 +1063,3 @@ def prune_layer(layer, index, dim=None):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError("Can't prune layer of class {}".format(layer.__class__))