diff --git a/transformers/configuration_utils.py b/transformers/configuration_utils.py
index 08cee75d81b..9c3360892dc 100644
--- a/transformers/configuration_utils.py
+++ b/transformers/configuration_utils.py
@@ -57,8 +57,19 @@ class PretrainedConfig(object):
         self.torchscript = kwargs.pop('torchscript', False)  # Only used by PyTorch models
         self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
         self.pruned_heads = kwargs.pop('pruned_heads', {})
+
+        # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
         self.is_decoder = kwargs.pop('is_decoder', False)
 
+        # Parameters for sequence generation
+        self.generate_length = kwargs.pop('generate_length', 10)
+        self.generate_do_sample = kwargs.pop('generate_do_sample', False)
+        self.generate_num_beams = kwargs.pop('generate_num_beams', 1)
+        self.generate_temperature = kwargs.pop('generate_temperature', 1.0)
+        self.generate_top_k = kwargs.pop('generate_top_k', 50)
+        self.generate_top_p = kwargs.pop('generate_top_p', 0.0)
+        self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0)
+
     def save_pretrained(self, save_directory):
         """ Save a configuration object to the directory `save_directory`, so that it
             can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py
index 74038351fd1..27d42c552a0 100644
--- a/transformers/modeling_utils.py
+++ b/transformers/modeling_utils.py
@@ -82,6 +82,7 @@ class PreTrainedModel(nn.Module):
                 "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                     self.__class__.__name__, self.__class__.__name__
                 ))
+
         # Save config in model
         self.config = config
 
@@ -89,93 +90,6 @@ class PreTrainedModel(nn.Module):
     def base_model(self):
         return getattr(self, self.base_model_prefix, self)
 
-    def decode(self,
-               prompt_ids=None,
-               device=torch.device('cpu'),
-               length=10,
-               do_sample=False,
-               temperature=1.,
-               k=9,
-               p=0,
-               repetition_penalty=1,
-               **model_kwargs):
-        """ Generic sequence generator for single-stack models with a LM head.
-
-        The method currently supports greedy decoding and sampling. See the
-        documentation of the `Sampler` class for more information about the
-        parameters related to sampling.
-
-        Params:
-            **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length)
-                The sequence to encode.
-            **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
-                The sequence used as a prompt for the generation. If `None` the method initializes
-                it as an empty `torch.LongTensor` of shape (1,)
-            **device**: (`optional`) `torch.device`
-                The device on which the prompt_ids will be initialized if not provided.
-            **length**: (`optional`) int
-                The length of the sequence to be generated.
-            **do_sample**: (`optional`) bool
-                If set to `False` we use greedy decoding; otherwise sampling.
-            **temperature**: (`optional`) float
-                The value used to module the next token probabilities.
-            **k**: (`optional`) int
-                The parameter used for k-filtering.
-            **p**: (`optional`) float
-                The parameter for nucleus sampling. Must be between 0 and 1.
-            **repetition_penalty**: (`optional`) float
-                The parameter for repetition penalty.
-        """
-
-        if prompt_ids is None:
-            prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)
-
-        # When the model does not have a LM head `get_output_embeddings`
-        # returns `None`. We use this mechanism to determine whether we
-        # should proceed with decoding or not.
-        if self.get_output_embeddings() is None:
-            raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.")
-
-        # The followings checks that the model is on the same device as the one
-        # that is specified. It only works for models that fit on one GPU.
-        model_device = next(self.parameters()).device
-        if model_device != prompt_ids.device:
-            warnings.warn(
-                "The model is not on the same device as the prompts. Expected {}, got {}.".format(
-                    prompt_ids.device, model_device
-                )
-            )
-
-        sampler_config = {
-            "k": k,
-            "p": p,
-            "do_sample": do_sample,
-            "temperature": temperature,
-            "repetition_penalty": repetition_penalty,
-        }
-        return self._greedy_decode_or_sample(prompt_ids, length, sampler_config, **model_kwargs)
-
-    def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **model_kwargs):
-        """ Generate text using greedy decoding or by sampling tokens."""
-        sampler = Sampler(**sampler_config)
-        generated_sequence = prompt_ids
-        with torch.no_grad():
-            for _ in trange(length):
-                arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
-                outputs = self(**arguments)
-                next_tokens_logits = outputs[0][:, -1, :]
-                next_tokens = sampler.get_one_token(
-                    next_tokens_logits, generated_sequence
-                )
-                generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)
-
-        return generated_sequence.squeeze(0)
-
-    def _prepare_inputs_for_decoding(self, input_ids, **kwargs):
-        arguments = {"input_ids": input_ids}
-        arguments.update(kwargs)
-        return arguments
-
     def get_input_embeddings(self):
         """ Get model's input embeddings
         """
@@ -306,6 +220,9 @@ class PreTrainedModel(nn.Module):
         # Tie weights if needed
         self.tie_weights()
 
+        # Initialize decoding head if we have output embeddings
+
+
     def prune_heads(self, heads_to_prune):
         """ Prunes heads of the base model.
@@ -571,6 +488,204 @@ class PreTrainedModel(nn.Module):
 
         return model
 
+    def generate(self, input_ids=None, length=None, do_sample=None, num_beams=None,
+                 temperature=None, top_k=None, top_p=None, repetition_penalty=None,
+                 **model_kwargs):
+        """ Generic sequence generator for single-stack models with a LM head.
+
+        The method currently supports greedy decoding and sampling. See the
+        documentation of the `Sampler` class for more information about the
+        parameters related to sampling. Arguments that are left unset fall back
+        on the corresponding `generate_*` defaults stored in the model config.
+
+        Params:
+            **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
+                The sequence used as a prompt for the generation. If `None` the method initializes
+                it as an empty `torch.LongTensor` of shape (1, 0).
+            **length**: (`optional`) int
+                The length of the sequence to be generated.
+            **do_sample**: (`optional`) bool
+                If set to `False` we use greedy decoding; otherwise sampling.
+            **num_beams**: (`optional`) int
+                The number of beams for beam search. Beam search is not implemented
+                yet, so this argument is currently ignored.
+            **temperature**: (`optional`) float
+                The value used to modulate the next token probabilities.
+            **top_k**: (`optional`) int
+                The parameter used for top-k filtering.
+            **top_p**: (`optional`) float
+                The parameter for nucleus sampling. Must be between 0 and 1.
+            **repetition_penalty**: (`optional`) float
+                The parameter for repetition penalty.
+ """ + + if input_ids is None: + input_ids = torch.tensor([[]], dtype=torch.long, device=next(self.parameters()).device) + + # We cannot generate if the model does not have a LM head + if self.get_output_embeddings() is None: + raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.") + + sampler_config = { + "k": k, + "p": p, + "do_sample": do_sample, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + } + + sampler = Sampler(**sampler_config) + generated_sequence = input_ids + for _ in trange(length): + arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs) + outputs = self(**arguments) + next_tokens_logits = outputs[0][:, -1, :] + next_tokens = sampler.get_one_token( + next_tokens_logits, generated_sequence + ) + generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1) + + return generated_sequence.squeeze(0) + + def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs): + return model_kwargs.update({"input_ids": input_ids}) + + +class Sampler(object): + r""" Sampler is used to generate sequences of ids from logit inputs. + + Greedy decoding, which consists in chosing the most probable token at each + step, is the default behaviour. Sampling with varying temperature, top_k + and nucleus filtering is also implemented. + + Attributes: + **device**: ``torch.device`` + Device on which the computations will be run. + **do_sample**: bool + Whether to sample or do greedy decoding. + **k**: int between 0 and vocab_size + Parameter for the top-k filtering + **p**: float between 0 and 1 + Parameter for the nucleus filtering + **temperature**: strictly positive float + Parameter used to modulate the distribution over ids. Low temperatures + put more emphasis on highly probably token while high temperatures tend + to smooth the probability distribution. + **repetition_penalty**: strictly postitive float + The penalty applied to repeating ids + """ + + def __init__( + self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0 + ): + self.k = k + self.p = p + self.do_sample = do_sample + self.temperature = temperature + self.repetition_penalty = repetition_penalty + + self.do_apply_repetition_penalty = True if repetition_penalty > 1 else False + + if self.p > 1: + warnings.warn( + """You are trying to apply nucleus filtering with a value of p greater than 1 ({}). + However p is a probability and its value must lie between 0 and 1. In effect, no filtering + will be applied. If this is not the behavior you expect, change the value of p.""".format( + self.p + ) + ) + + def get_one_token(self, next_token_logits, past_sequence): + logits = self.apply_repetition_penalty(next_token_logits, past_sequence) + if self.do_sample: + logits = self.apply_temperature(logits) + logits = self.apply_top_k_filter(logits) + logits = self.apply_nucleus_filter(logits) + return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + return torch.argmax(logits, dim=-1).unsqueeze(-1) + + def apply_repetition_penalty(self, logits, past_sequence): + """ Apply a penalty to tokens that appear more than once in the + generated sequence. + + .. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer + language model for controllable generation." arXiv preprint + arXiv:1909.05858 (2019). 
+ """ + if self.do_apply_repetition_penalty: + generated_token_idx = set(past_sequence[0].tolist()) + for token_idx in generated_token_idx: + logits[0, token_idx] /= self.repetition_penalty + return logits + + def apply_temperature(self, logits): + """ Shape the tokens' distribution through temperature. The higher the value + of the temperature, the more skewed towards high probability events the + distribution is. + + .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning. + MIT press, 2016. + """ + # when dividing a float by 0, torch returns inf which in turns breaks the + # multinomial with an error message that is not very helpful. It is better + # for the user to break the execution and explain why. + if self.temperature == 0: + raise ZeroDivisionError( + """You are trying to sample with a temperature equal to 0. + If you wanted to do greedy sampling, set instead `do_sample` to False. + Otherwise set the temperature to a value different from 0.""" + ) + return logits / self.temperature + + def apply_top_k_filter(self, logits): + """ Use the probability distribution of the tokens to determine the set + to be sampled from. Specifically we select the set of size k such that + the sum of its items' probabilities is maximum. + + .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural + story generation." arXiv preprint arXiv:1805.04833 (2018). + """ + if self.k > 0: + vocabulary_size = logits.size(-1) + if self.k > vocabulary_size: + warnings.warn( + """You provided a value for k ({}) that is larger than the vocabulary size ({}). + We adjusted k's value to the vocabulary size; if that was what you intended to do + we recommend setting k to 0 instead. It this is not the behavior you expected, + choose a value of k that is smaller than the vocabulary size.""".format( + self.k, vocabulary_size + ) + ) + self.k = vocabulary_size + + indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None] + logits[indices_to_remove] = -float("Inf") + + return logits + + def apply_nucleus_filter(self, logits): + """ Use the probability distribution of the tokens to determine the set + to be sampled from. Specifically, choose the smallest set such that the + sum of its items' probabilities is greater than a number p in [0,1]. + + .. Holtzman, Ari, et al. "The curious case of neural text + degeneration." arXiv preprint arXiv:1904.09751 (2019). + """ + if self.p > 0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + sorted_probabilities = F.softmax(sorted_logits, dim=-1) + cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1) + + # Remove tokens with cumulative probability above the threshold, + # but keep the first token above the threshold. + sorted_indices_to_remove = cumulative_probabilities > self.p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + dim=-1, index=sorted_indices, src=sorted_indices_to_remove + ) + logits[indices_to_remove] = -float("Inf") + + return logits + class Conv1D(nn.Module): def __init__(self, nf, nx): @@ -948,143 +1063,3 @@ def prune_layer(layer, index, dim=None): return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) else: raise ValueError("Can't prune layer of class {}".format(layer.__class__)) - - -class Sampler(object): - r""" Sampler is used to generate sequences of ids from logit inputs. 
-
-    Greedy decoding, which consists in chosing the most probable token at each
-    step, is the default behaviour. Sampling with varying temperature, top_k
-    and nucleus filtering is also implemented.
-
-    Attributes:
-        **device**: ``torch.device``
-            Device on which the computations will be run.
-        **do_sample**: bool
-            Whether to sample or do greedy decoding.
-        **k**: int between 0 and vocab_size
-            Parameter for the top-k filtering
-        **p**: float between 0 and 1
-            Parameter for the nucleus filtering
-        **temperature**: strictly positive float
-            Parameter used to modulate the distribution over ids. Low temperatures
-            put more emphasis on highly probably token while high temperatures tend
-            to smooth the probability distribution.
-        **repetition_penalty**: strictly postitive float
-            The penalty applied to repeating ids
-    """
-
-    def __init__(
-        self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0
-    ):
-        self.k = k
-        self.p = p
-        self.do_sample = do_sample
-        self.temperature = temperature
-        self.repetition_penalty = repetition_penalty
-
-        self.do_apply_repetition_penalty = True if repetition_penalty > 1 else False
-
-        if self.p > 1:
-            warnings.warn(
-                """You are trying to apply nucleus filtering with a value of p greater than 1 ({}).
-                However p is a probability and its value must lie between 0 and 1. In effect, no filtering
-                will be applied. If this is not the behavior you expect, change the value of p.""".format(
-                    self.p
-                )
-            )
-
-    def get_one_token(self, next_token_logits, past_sequence):
-        logits = self.apply_repetition_penalty(next_token_logits, past_sequence)
-        if self.do_sample:
-            logits = self.apply_temperature(logits)
-            logits = self.apply_top_k_filter(logits)
-            logits = self.apply_nucleus_filter(logits)
-            return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
-        return torch.argmax(logits, dim=-1).unsqueeze(-1)
-
-    def apply_repetition_penalty(self, logits, past_sequence):
-        """ Apply a penalty to tokens that appear more than once in the
-        generated sequence.
-
-        .. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer
-           language model for controllable generation." arXiv preprint
-           arXiv:1909.05858 (2019).
-        """
-        if self.do_apply_repetition_penalty:
-            generated_token_idx = set(past_sequence[0].tolist())
-            for token_idx in generated_token_idx:
-                logits[0, token_idx] /= self.repetition_penalty
-        return logits
-
-    def apply_temperature(self, logits):
-        """ Shape the tokens' distribution through temperature. The higher the value
-        of the temperature, the more skewed towards high probability events the
-        distribution is.
-
-        .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning.
-           MIT press, 2016.
-        """
-        # when dividing a float by 0, torch returns inf which in turns breaks the
-        # multinomial with an error message that is not very helpful. It is better
-        # for the user to break the execution and explain why.
-        if self.temperature == 0:
-            raise ZeroDivisionError(
-                """You are trying to sample with a temperature equal to 0.
-                If you wanted to do greedy sampling, set instead `do_sample` to False.
-                Otherwise set the temperature to a value different from 0."""
-            )
-        return logits / self.temperature
-
-    def apply_top_k_filter(self, logits):
-        """ Use the probability distribution of the tokens to determine the set
-        to be sampled from. Specifically we select the set of size k such that
-        the sum of its items' probabilities is maximum.
-
-        .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural
-           story generation." arXiv preprint arXiv:1805.04833 (2018).
-        """
-        if self.k > 0:
-            vocabulary_size = logits.size(-1)
-            if self.k > vocabulary_size:
-                warnings.warn(
-                    """You provided a value for k ({}) that is larger than the vocabulary size ({}).
-                    We adjusted k's value to the vocabulary size; if that was what you intended to do
-                    we recommend setting k to 0 instead. It this is not the behavior you expected,
-                    choose a value of k that is smaller than the vocabulary size.""".format(
-                        self.k, vocabulary_size
-                    )
-                )
-                self.k = vocabulary_size
-
-            indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None]
-            logits[indices_to_remove] = -float("Inf")
-
-        return logits
-
-    def apply_nucleus_filter(self, logits):
-        """ Use the probability distribution of the tokens to determine the set
-        to be sampled from. Specifically, choose the smallest set such that the
-        sum of its items' probabilities is greater than a number p in [0,1].
-
-        .. Holtzman, Ari, et al. "The curious case of neural text
-           degeneration." arXiv preprint arXiv:1904.09751 (2019).
-        """
-        if self.p > 0:
-            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-            sorted_probabilities = F.softmax(sorted_logits, dim=-1)
-            cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)
-
-            # Remove tokens with cumulative probability above the threshold,
-            # but keep the first token above the threshold.
-            sorted_indices_to_remove = cumulative_probabilities > self.p
-            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-            sorted_indices_to_remove[..., 0] = 0
-
-            # scatter sorted tensors to original indexing
-            indices_to_remove = sorted_indices_to_remove.scatter(
-                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
-            )
-            logits[indices_to_remove] = -float("Inf")
-
-        return logits
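
Usage sketch (not part of the diff): assuming a GPT-2 LM-head checkpoint, the new `generate` entry point would be driven roughly as below. The prompt, the sampling values, and the reliance on the `generate_*` config fallbacks fixed above are illustrative assumptions, not something this patch prescribes.

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.eval()

    # `generate` expects a (1, sequence_length) tensor of token ids as prompt
    prompt_ids = torch.tensor([tokenizer.encode("The weather today is")], dtype=torch.long)

    # Sample 20 tokens with temperature, top-k and nucleus filtering; any
    # argument left unset falls back on the config's `generate_*` defaults
    output_ids = model.generate(
        input_ids=prompt_ids,
        length=20,
        do_sample=True,
        temperature=0.9,
        top_k=40,
        top_p=0.9,
        repetition_penalty=1.2,
    )

    print(tokenizer.decode(output_ids.tolist()))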
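To make the nucleus-filtering step concrete, here is a small self-contained check (again illustrative, not part of the patch) of the same masking logic `apply_nucleus_filter` uses. With p=0.9, only the smallest set of tokens whose cumulative probability exceeds 0.9 survives:

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[3.0, 2.5, 1.0, -1.0, -2.0]])  # probs ~ [0.57, 0.34, 0.08, 0.01, 0.004]

    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    to_remove = cumulative > 0.9
    to_remove[..., 1:] = to_remove[..., :-1].clone()  # shift right: keep the first token above the threshold
    to_remove[..., 0] = 0

    # scatter the mask back from sorted order to the original vocabulary order
    indices_to_remove = to_remove.scatter(dim=-1, index=sorted_indices, src=to_remove)
    logits[indices_to_remove] = -float("Inf")
    print(logits)  # tensor([[3.0, 2.5, -inf, -inf, -inf]]): only the top-2 tokens remain sampleable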