refactoring generation
This commit is contained in:
parent 07bc8efbc3
commit a468870fd2
@@ -57,8 +57,19 @@ class PretrainedConfig(object):
        self.torchscript = kwargs.pop('torchscript', False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
        self.pruned_heads = kwargs.pop('pruned_heads', {})

        # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
        self.is_decoder = kwargs.pop('is_decoder', False)

        # Parameters for sequence generation
        self.generate_length = kwargs.pop('generate_length', 10)
        self.generate_do_sample = kwargs.pop('generate_do_sample', False)
        self.generate_num_beams = kwargs.pop('generate_num_beams', 1)
        self.generate_temperature = kwargs.pop('generate_temperature', 1.0)
        self.generate_top_k = kwargs.pop('generate_top_k', 50)
        self.generate_top_p = kwargs.pop('generate_top_p', 0.0)
        self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0)

    def save_pretrained(self, save_directory):
        """ Save a configuration object to the directory `save_directory`, so that it
        can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
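
A minimal sketch of how the new configuration defaults behave (not part of the diff; `GPT2Config` stands in for any `PretrainedConfig` subclass that forwards extra kwargs to its parent, and the override values are illustrative):

from transformers import GPT2Config

# Any generate_* key can be overridden at construction time; the rest fall
# back to the defaults popped in PretrainedConfig.__init__ above.
config = GPT2Config(generate_length=20, generate_do_sample=True)
assert config.generate_top_k == 50       # library default from above
assert config.generate_length == 20      # user override
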
@@ -82,6 +82,7 @@ class PreTrainedModel(nn.Module):
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))

        # Save config in model
        self.config = config
@@ -89,93 +90,6 @@ class PreTrainedModel(nn.Module):
    def base_model(self):
        return getattr(self, self.base_model_prefix, self)

    def decode(self,
               prompt_ids=None,
               device=torch.device('cpu'),
               length=10,
               do_sample=False,
               temperature=1.,
               k=9,
               p=0,
               repetition_penalty=1,
               **model_kwargs):
        """ Generic sequence generator for single-stack models with a LM head.

        The method currently supports greedy decoding and sampling. See the
        documentation of the `Sampler` class for more information about the
        parameters related to sampling.

        Params:
            **prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape (1,).
            **device**: (`optional`) `torch.device`
                The device on which the prompt_ids will be initialized if not provided.
            **length**: (`optional`) int
                The length of the sequence to be generated.
            **do_sample**: (`optional`) bool
                If set to `False` we use greedy decoding; otherwise sampling.
            **temperature**: (`optional`) float
                The value used to modulate the next token probabilities.
            **k**: (`optional`) int
                The parameter used for top-k filtering.
            **p**: (`optional`) float
                The parameter for nucleus sampling. Must be between 0 and 1.
            **repetition_penalty**: (`optional`) float
                The parameter for repetition penalty.
        """
        if prompt_ids is None:
            prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)

        # When the model does not have a LM head `get_output_embeddings`
        # returns `None`. We use this mechanism to determine whether we
        # should proceed with decoding or not.
        if self.get_output_embeddings() is None:
            raise AttributeError("You tried to generate sequences with a model that does not have a LM head.")

        # The following checks that the model is on the same device as the one
        # that is specified. It only works for models that fit on one GPU.
        model_device = next(self.parameters()).device
        if model_device != prompt_ids.device:
            warnings.warn(
                "The model is not on the same device as the prompts. Expected {}, got {}.".format(
                    prompt_ids.device, model_device
                )
            )

        sampler_config = {
            "k": k,
            "p": p,
            "do_sample": do_sample,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
        }
        return self._greedy_decode_or_sample(prompt_ids, length, sampler_config, **model_kwargs)

    def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **model_kwargs):
        """ Generate text using greedy decoding or by sampling tokens."""
        sampler = Sampler(**sampler_config)
        generated_sequence = prompt_ids
        with torch.no_grad():
            for _ in trange(length):
                arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
                outputs = self(**arguments)
                next_tokens_logits = outputs[0][:, -1, :]
                next_tokens = sampler.get_one_token(
                    next_tokens_logits, generated_sequence
                )
                generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)

        return generated_sequence.squeeze(0)

    def _prepare_inputs_for_decoding(self, input_ids, **kwargs):
        arguments = {"input_ids": input_ids}
        arguments.update(kwargs)
        return arguments

    def get_input_embeddings(self):
        """ Get model's input embeddings
        """
@@ -306,6 +220,9 @@ class PreTrainedModel(nn.Module):
        # Tie weights if needed
        self.tie_weights()

        # Initialize decoding head if we have output embeddings

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.
@@ -571,6 +488,204 @@ class PreTrainedModel(nn.Module):

        return model

    def generate(self, input_ids=None, length=None, do_sample=False, num_beams=None,
                 temperature=None, top_k=None, top_p=None, repetition_penalty=None,
                 **model_kwargs):
        """ Generic sequence generator for single-stack models with a LM head.

        The method currently supports greedy decoding and sampling. See the
        documentation of the `Sampler` class for more information about the
        parameters related to sampling.

        Params:
            **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape (1,).
            **length**: (`optional`) int
                The length of the sequence to be generated.
            **do_sample**: (`optional`) bool
                If set to `False` we use greedy decoding; otherwise sampling.
            **temperature**: (`optional`) float
                The value used to modulate the next token probabilities.
            **top_k**: (`optional`) int
                The parameter used for top-k filtering.
            **top_p**: (`optional`) float
                The parameter for nucleus sampling. Must be between 0 and 1.
            **repetition_penalty**: (`optional`) float
                The parameter for repetition penalty.
        """
        if input_ids is None:
            input_ids = torch.tensor([[]], dtype=torch.long, device=next(self.parameters()).device)

        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError("You tried to generate sequences with a model that does not have a LM head.")

        sampler_config = {
            "k": top_k,
            "p": top_p,
            "do_sample": do_sample,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
        }

        sampler = Sampler(**sampler_config)
        generated_sequence = input_ids
        for _ in trange(length):
            arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
            outputs = self(**arguments)
            next_tokens_logits = outputs[0][:, -1, :]
            next_tokens = sampler.get_one_token(
                next_tokens_logits, generated_sequence
            )
            generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)

        return generated_sequence.squeeze(0)

    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
        # `dict.update` returns None, so update first and return the dict
        model_kwargs.update({"input_ids": input_ids})
        return model_kwargs

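A quick end-to-end sketch of the new `generate` API (not part of the diff; the GPT-2 classes and checkpoint name are illustrative, and `length` must be passed explicitly since the `None` defaults are not yet wired to the `generate_*` config values):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

prompt_ids = torch.tensor([tokenizer.encode("The refactored generator")])
with torch.no_grad():  # generate() itself no longer wraps the loop in no_grad
    output_ids = model.generate(input_ids=prompt_ids, length=20, do_sample=True,
                                temperature=0.7, top_k=50, top_p=0.9,
                                repetition_penalty=1.2)
print(tokenizer.decode(output_ids.tolist()))
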
class Sampler(object):
    r""" Sampler is used to generate sequences of ids from logit inputs.

    Greedy decoding, which consists in choosing the most probable token at each
    step, is the default behaviour. Sampling with varying temperature, top_k
    and nucleus filtering is also implemented.

    Attributes:
        **do_sample**: bool
            Whether to sample or do greedy decoding.
        **k**: int between 0 and vocab_size
            Parameter for the top-k filtering
        **p**: float between 0 and 1
            Parameter for the nucleus filtering
        **temperature**: strictly positive float
            Parameter used to modulate the distribution over ids. Low temperatures
            put more emphasis on highly probable tokens while high temperatures tend
            to smooth the probability distribution.
        **repetition_penalty**: strictly positive float
            The penalty applied to repeating ids
    """

    def __init__(
        self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0
    ):
        self.k = k
        self.p = p
        self.do_sample = do_sample
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty

        self.do_apply_repetition_penalty = repetition_penalty > 1

        if self.p > 1:
            warnings.warn(
                """You are trying to apply nucleus filtering with a value of p greater than 1 ({}).
                However p is a probability and its value must lie between 0 and 1. In effect, no filtering
                will be applied. If this is not the behavior you expect, change the value of p.""".format(
                    self.p
                )
            )

    def get_one_token(self, next_token_logits, past_sequence):
        logits = self.apply_repetition_penalty(next_token_logits, past_sequence)
        if self.do_sample:
            logits = self.apply_temperature(logits)
            logits = self.apply_top_k_filter(logits)
            logits = self.apply_nucleus_filter(logits)
            return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
        return torch.argmax(logits, dim=-1).unsqueeze(-1)

    def apply_repetition_penalty(self, logits, past_sequence):
        """ Apply a penalty to tokens that appear more than once in the
        generated sequence.

        .. Keskar, Nitish Shirish, et al. "CTRL: A conditional transformer
        language model for controllable generation." arXiv preprint
        arXiv:1909.05858 (2019).
        """
        if self.do_apply_repetition_penalty:
            generated_token_idx = set(past_sequence[0].tolist())
            for token_idx in generated_token_idx:
                logits[0, token_idx] /= self.repetition_penalty
        return logits

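    # Worked illustration (not in the diff): with repetition_penalty=2.0 and
    # token 5 already generated, a logit of 4.0 at index 5 becomes 2.0, halving
    # its unnormalized score before the softmax. Note that division only acts
    # as a penalty for positive logits; a negative logit moves toward zero and
    # becomes more likely, a known quirk of the CTRL formulation cited above.
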
    def apply_temperature(self, logits):
        """ Shape the tokens' distribution through temperature. The lower the value
        of the temperature, the more skewed towards high probability events the
        distribution is.

        .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning.
        MIT press, 2016.
        """
        # When dividing a float by 0, torch returns inf, which in turn breaks the
        # multinomial with an error message that is not very helpful. It is better
        # for the user to break the execution and explain why.
        if self.temperature == 0:
            raise ZeroDivisionError(
                """You are trying to sample with a temperature equal to 0.
                If you wanted to do greedy decoding, set instead `do_sample` to False.
                Otherwise set the temperature to a value different from 0."""
            )
        return logits / self.temperature

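    # Worked illustration (not in the diff): dividing logits [2.0, 1.0] by a
    # temperature of 0.5 gives [4.0, 2.0]; after the softmax the probabilities
    # go from roughly [0.73, 0.27] to [0.88, 0.12], so low temperatures sharpen
    # the distribution while temperatures above 1 flatten it.
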
    def apply_top_k_filter(self, logits):
        """ Use the probability distribution of the tokens to determine the set
        to be sampled from. Specifically we select the set of size k such that
        the sum of its items' probabilities is maximum.

        .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural
        story generation." arXiv preprint arXiv:1805.04833 (2018).
        """
        if self.k > 0:
            vocabulary_size = logits.size(-1)
            if self.k > vocabulary_size:
                warnings.warn(
                    """You provided a value for k ({}) that is larger than the vocabulary size ({}).
                    We adjusted k's value to the vocabulary size; if that was what you intended to do
                    we recommend setting k to 0 instead. If this is not the behavior you expected,
                    choose a value of k that is smaller than the vocabulary size.""".format(
                        self.k, vocabulary_size
                    )
                )
                self.k = vocabulary_size

            indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None]
            logits[indices_to_remove] = -float("Inf")

        return logits

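    # Worked illustration (not in the diff): with k=2 and logits
    # [1.0, 4.0, 2.0, 3.0], torch.topk keeps 4.0 and 3.0, so the k-th largest
    # value is 3.0; indices 0 and 2 fall below it and are masked to -inf, and
    # the softmax redistributes all probability mass over the two survivors.
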
    def apply_nucleus_filter(self, logits):
        """ Use the probability distribution of the tokens to determine the set
        to be sampled from. Specifically, choose the smallest set such that the
        sum of its items' probabilities is greater than a number p in [0,1].

        .. Holtzman, Ari, et al. "The curious case of neural text
        degeneration." arXiv preprint arXiv:1904.09751 (2019).
        """
        if self.p > 0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            sorted_probabilities = F.softmax(sorted_logits, dim=-1)
            cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)

            # Remove tokens with cumulative probability above the threshold,
            # but keep the first token above the threshold.
            sorted_indices_to_remove = cumulative_probabilities > self.p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Scatter the sorted mask back to the original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = -float("Inf")

        return logits
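
    # Worked illustration (not in the diff): with p=0.9 and sorted
    # probabilities [0.5, 0.3, 0.15, 0.05], the cumulative sums are
    # [0.5, 0.8, 0.95, 1.0]. Tokens whose cumulative probability exceeds 0.9
    # are masked, but the shift-by-one keeps the first token that crosses the
    # threshold, so the top three tokens survive and only the last is dropped.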
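A self-contained sketch of the `Sampler` in isolation (not part of the diff; the toy logits are illustrative, standing in for a model's next-token logits, and the `Sampler` class defined above is assumed to be in scope):

import torch

sampler = Sampler(do_sample=True, k=2, p=0.9, temperature=0.7, repetition_penalty=1.2)

past_sequence = torch.tensor([[1, 3]])                         # ids generated so far
next_token_logits = torch.tensor([[1.0, 4.0, 2.0, 0.5, 3.0]])  # one step of logits
next_token = sampler.get_one_token(next_token_logits, past_sequence)
print(next_token.shape)  # torch.Size([1, 1]): one id drawn from the filtered distribution
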
class Conv1D(nn.Module):
    def __init__(self, nf, nx):

@@ -948,143 +1063,3 @@ def prune_layer(layer, index, dim=None):
        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
    else:
        raise ValueError("Can't prune layer of class {}".format(layer.__class__))