Generate: store special token tensors under a unique variable name (#31980)

* rename stuff

* english; this one shouldn't be changed

* add a _ to the new var names

* musicgen

* derp
Joao Gante 2024-07-22 14:06:49 +01:00 committed by GitHub
parent aa8f86a421
commit c38c55f4fb
4 changed files with 187 additions and 273 deletions
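In short: `generate` used to overwrite `generation_config.eos_token_id` (and the other special token fields) with tensors during preparation. After this change the tensor forms live under separate private attributes (`_bos_token_tensor`, `_eos_token_tensor`, `_pad_token_tensor`, `_decoder_start_token_tensor`), the user-visible fields keep their plain values, and, per the NOTE in the diff below, keeping the two apart is what enables end-to-end `torch.compile`. As a side effect, a user-provided `GenerationConfig` that leaves a special token unset now falls back to the model's default `generation_config`.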

src/transformers/generation/utils.py

@ -754,12 +754,12 @@ class GenerationMixin:
warpers = LogitsProcessorList()
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
# better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
# better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
if generation_config.num_beams > 1:
if isinstance(generation_config.eos_token_id, list):
min_tokens_to_keep = len(generation_config.eos_token_id) + 1
elif isinstance(generation_config.eos_token_id, torch.Tensor):
min_tokens_to_keep = generation_config.eos_token_id.shape[0] + 1
if isinstance(generation_config._eos_token_tensor, list):
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
else:
min_tokens_to_keep = 2
else:
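As a quick illustration of the branch above (a standalone sketch, not library code): with the eos ids stored as a 1D tensor under `_eos_token_tensor`, the number of logits to keep in beam methods is simply the tensor's length plus one.

import torch

# hypothetical eos ids, in the 1D tensor form produced by `_prepare_special_tokens`
eos_token_tensor = torch.tensor([2, 32000])
min_tokens_to_keep = eos_token_tensor.shape[0] + 1  # 3: every eos id plus one non-eos token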
@ -863,31 +863,31 @@ class GenerationMixin:
processors.append(
NoBadWordsLogitsProcessor(
generation_config.bad_words_ids,
generation_config.eos_token_id,
generation_config._eos_token_tensor,
)
)
if (
generation_config.min_length is not None
and generation_config.eos_token_id is not None
and generation_config._eos_token_tensor is not None
and generation_config.min_length > 0
):
processors.append(
MinLengthLogitsProcessor(
generation_config.min_length,
generation_config.eos_token_id,
generation_config._eos_token_tensor,
device=device,
)
)
if (
generation_config.min_new_tokens is not None
and generation_config.eos_token_id is not None
and generation_config._eos_token_tensor is not None
and generation_config.min_new_tokens > 0
):
processors.append(
MinNewTokensLengthLogitsProcessor(
input_ids_seq_length,
generation_config.min_new_tokens,
generation_config.eos_token_id,
generation_config._eos_token_tensor,
device=device,
)
)
@ -918,7 +918,7 @@ class GenerationMixin:
processors.append(
ExponentialDecayLengthPenalty(
generation_config.exponential_decay_length_penalty,
generation_config.eos_token_id,
generation_config._eos_token_tensor,
input_ids_seq_length,
)
)
@ -997,8 +997,8 @@ class GenerationMixin:
"stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
)
criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
if generation_config.eos_token_id is not None:
criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
if generation_config._eos_token_tensor is not None:
criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
return criteria
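For intuition, the eos check behind `EosTokenCriteria` maps directly onto the tensor form; a minimal standalone sketch with toy ids (not the criteria's actual code):

import torch

eos_token_tensor = torch.tensor([2])              # assumed eos id(s), always kept 1D
last_tokens = torch.tensor([2, 15, 2])            # last generated token per sequence
done = torch.isin(last_tokens, eos_token_tensor)  # tensor([ True, False,  True])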
@ -1349,13 +1349,15 @@ class GenerationMixin:
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
"""
Prepares the base generation config, then applies any generation configuration options from kwargs.
Prepares the base generation config, then applies any generation configuration options from kwargs. This
function handles retrocompatibility with respect to configuration files.
"""
# TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
# replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
# the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
using_model_generation_config = False
if generation_config is None:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# three conditions must be met
@ -1378,6 +1380,7 @@ class GenerationMixin:
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self.generation_config = new_generation_config
using_model_generation_config = True
generation_config = self.generation_config
# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
@ -1395,6 +1398,16 @@ class GenerationMixin:
else:
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
if not using_model_generation_config:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.decoder_start_token_id is None:
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
return generation_config, model_kwargs
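A hedged usage sketch of the new fallback (checkpoint name borrowed from the test added at the bottom of this diff; the fallback is applied to the internal deep copy, so the user's config object is left untouched):

from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")

# no special tokens set here: bos/eos/pad/decoder_start fall back to the
# values in `model.generation_config`
generation_config = GenerationConfig(max_new_tokens=5)
out = model.generate(generation_config=generation_config)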
@ -1493,52 +1506,46 @@ class GenerationMixin:
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
"""
# Convert special tokens to tensors (if they exist either in kwargs or in self.config)
def _tensor_or_none(token_kwargs, token_self, device=None):
if device is None:
device = self.device
token = token_kwargs if token_kwargs is not None else token_self
# Convert special tokens to tensors
def _tensor_or_none(token, device=None):
if token is None:
return token
elif isinstance(token, torch.Tensor):
return token.to(device)
device = device if device is not None else self.device
if isinstance(token, torch.Tensor):
return token.to(device)
return torch.tensor(token, device=device, dtype=torch.long)
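The simplified helper no longer consults the model's default config itself; that fallback now happens earlier, in the config-preparation step above. A standalone sketch, with `self.device` replaced by an explicit default:

import torch

def _tensor_or_none(token, device="cpu"):  # the real helper defaults to `self.device`
    if token is None:
        return token
    if isinstance(token, torch.Tensor):
        return token.to(device)
    return torch.tensor(token, device=device, dtype=torch.long)

_tensor_or_none(None)        # None
_tensor_or_none(2)           # tensor(2)
_tensor_or_none([2, 32000])  # tensor([2, 32000])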
bos_token_id = _tensor_or_none(
generation_config.bos_token_id, self.generation_config.bos_token_id, device=device
)
eos_token_id = _tensor_or_none(
generation_config.eos_token_id, self.generation_config.eos_token_id, device=device
)
pad_token_id = _tensor_or_none(
generation_config.pad_token_id, self.generation_config.pad_token_id, device=device
)
decoder_start_token_id = _tensor_or_none(
generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device
)
bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
# for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
if self.config.is_encoder_decoder:
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
decoder_start_token_tensor = (
decoder_start_token_tensor if decoder_start_token_tensor is not None else bos_token_tensor
)
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
if eos_token_id is not None and eos_token_id.ndim == 0:
eos_token_id = eos_token_id.unsqueeze(0)
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
eos_token_tensor = eos_token_tensor.unsqueeze(0)
# Set pad token if unset (and there are conditions to do so)
if pad_token_id is None and eos_token_id is not None:
if pad_token_tensor is None and eos_token_tensor is not None:
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
pad_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
pad_token_tensor = eos_token_tensor[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
# we can't infer attn mask if pad token is set to be eos token in model's generation config
if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
if (
eos_token_tensor is not None
and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
):
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once(
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
@ -1547,21 +1554,26 @@ class GenerationMixin:
)
# Sanity checks/warnings
if self.config.is_encoder_decoder and decoder_start_token_id is None:
if self.config.is_encoder_decoder and decoder_start_token_tensor is None:
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()):
if eos_token_tensor is not None and (
torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
):
logger.warning(
f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not "
f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation will not "
"stop until the maximum length is reached. Depending on other flags, it may even crash."
)
# Update generation config with the updated special tokens tensors
generation_config.bos_token_id = bos_token_id
generation_config.eos_token_id = eos_token_id
generation_config.pad_token_id = pad_token_id
generation_config.decoder_start_token_id = decoder_start_token_id
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
# (in their non-tensor form), in order to enable end-to-end compilation. See
# https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
generation_config._bos_token_tensor = bos_token_tensor
generation_config._eos_token_tensor = eos_token_tensor
generation_config._pad_token_tensor = pad_token_tensor
generation_config._decoder_start_token_tensor = decoder_start_token_tensor
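The NOTE above is the heart of the PR; a sketch of the resulting convention (toy config class, not the real `GenerationConfig`): the user-facing field keeps its plain value, while the decoding loops read a tensor twin stored under a different name, so `generate` no longer mutates the original fields with tensors — which, per the linked cudagraph-trees limitations, would stand in the way of end-to-end compilation.

import torch

class Cfg:  # toy stand-in for GenerationConfig
    eos_token_id = 2  # user-facing value stays a plain int

cfg = Cfg()
# tensor twin under a separate private attribute, as `_prepare_special_tokens` does
cfg._eos_token_tensor = torch.tensor([cfg.eos_token_id])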
@torch.no_grad()
def generate(
@ -1696,10 +1708,10 @@ class GenerationMixin:
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
if (
generation_config.pad_token_id is not None
generation_config._pad_token_tensor is not None
and batch_size > 1
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
@ -1716,7 +1728,7 @@ class GenerationMixin:
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
@ -1731,7 +1743,7 @@ class GenerationMixin:
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
decoder_start_token_id=generation_config._decoder_start_token_tensor,
device=inputs_tensor.device,
)
else:
@ -2279,7 +2291,7 @@ class GenerationMixin:
raise ValueError("DoLa decoding is only available for decoder-only models.")
# init values
pad_token_id = generation_config.pad_token_id
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@ -2486,7 +2498,7 @@ class GenerationMixin:
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
top_k = generation_config.top_k
penalty_alpha = generation_config.penalty_alpha
pad_token_id = generation_config.pad_token_id
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@ -2877,7 +2889,7 @@ class GenerationMixin:
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config.pad_token_id
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@ -3084,8 +3096,8 @@ class GenerationMixin:
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config.pad_token_id
eos_token_id = generation_config.eos_token_id
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@ -3366,8 +3378,8 @@ class GenerationMixin:
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config.pad_token_id
eos_token_id = generation_config.eos_token_id
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@ -3658,8 +3670,8 @@ class GenerationMixin:
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config.pad_token_id
eos_token_id = generation_config.eos_token_id
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores

src/transformers/models/musicgen/modeling_musicgen.py

@ -1539,75 +1539,43 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
# 3. Define model inputs
input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = input_ids.shape[0] // self.num_codebooks
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
input_ids, generation_config.pad_token_id, generation_config.eos_token_id
input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)
# 5. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
logger.warning(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
"to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation."
)
elif generation_config.max_new_tokens is not None:
if not has_default_max_length:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= generation_config.max_length:
logger.warning(
f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=input_ids,
input_ids_length=input_ids_length,
)
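The inline length handling deleted above boils down to one rule, now centralized in `_prepare_generated_length`; a sketch of that rule as taken from the removed lines:

# `max_new_tokens`, when set, takes precedence over `max_length`
max_new_tokens, input_ids_length = 256, 12
max_length = max_new_tokens + input_ids_length  # 268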
# 6. Prepare `input_ids` which will be used for auto-regressive generation
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
pad_token_id=generation_config.decoder_start_token_id,
pad_token_id=generation_config._decoder_start_token_tensor,
max_length=generation_config.max_length,
)
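For intuition about the comment above, a toy illustration of the codebook delay (assumptions: a 1-step offset per codebook and a -1 sentinel; this is not Musicgen's actual `build_delay_pattern_mask`):

import torch

num_codebooks, seq_len, pad = 3, 6, -1
ids = torch.arange(seq_len).repeat(num_codebooks, 1)
for k in range(num_codebooks):
    ids[k] = torch.roll(ids[k], shifts=k)  # codebook k is offset by k steps
    ids[k, :k] = pad                       # positions before its start are padded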
@ -1628,7 +1596,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
input_ids_seq_length=input_ids_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
@ -1682,7 +1650,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
# revert the pattern delay mask by filtering the pad token id
output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
batch_size, self.num_codebooks, -1
)
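And the matching sketch for the step above: drop the pad positions, then reshape back to `(batch, num_codebooks, -1)` (toy values; the real code compares against `generation_config._pad_token_tensor`):

import torch

pad = torch.tensor(-1)
output_ids = torch.tensor([-1, 5, 7, -1, 6, 8])  # flattened toy output
output_ids = output_ids[output_ids != pad].reshape(1, 2, -1)
# tensor([[[5, 7], [6, 8]]])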
@ -2590,39 +2558,23 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)
if "encoder_outputs" not in model_kwargs:
@ -2642,45 +2594,28 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
decoder_start_token_id=generation_config._decoder_start_token_tensor,
bos_token_id=generation_config._bos_token_tensor,
device=inputs_tensor.device,
)
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
logger.warning(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
"to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation."
)
elif generation_config.max_new_tokens is not None:
if not has_default_max_length:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= generation_config.max_length:
logger.warning(
f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids,
pad_token_id=generation_config.decoder_start_token_id,
pad_token_id=generation_config._decoder_start_token_tensor,
max_length=generation_config.max_length,
)
# stash the delay mask so that we don't have to recompute in each forward pass
@ -2701,7 +2636,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
@ -2756,7 +2691,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
# revert the pattern delay mask by filtering the pad token id
output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
batch_size, self.decoder.num_codebooks, -1
)

src/transformers/models/musicgen_melody/modeling_musicgen_melody.py

@ -1375,6 +1375,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
return input_ids
@torch.no_grad()
# Ignore copy
def generate(
self,
inputs: Optional[torch.Tensor] = None,
@ -1460,75 +1461,43 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
# 3. Define model inputs
input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = input_ids.shape[0] // self.num_codebooks
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale
# Ignore copy
if model_kwargs.get("attention_mask", None) is None:
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
input_ids, generation_config.pad_token_id, generation_config.eos_token_id
input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)
# 5. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
logger.warning(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
"to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation."
)
elif generation_config.max_new_tokens is not None:
if not has_default_max_length:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= generation_config.max_length:
logger.warning(
f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=input_ids,
input_ids_length=input_ids_length,
)
# 6. Prepare `input_ids` which will be used for auto-regressive generation
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen)
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
pad_token_id=generation_config.decoder_start_token_id,
pad_token_id=generation_config._decoder_start_token_tensor,
max_length=generation_config.max_length,
)
@ -1549,7 +1518,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
input_ids_seq_length=input_ids_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
@ -1603,7 +1572,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
# revert the pattern delay mask by filtering the pad token id
output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
batch_size, self.num_codebooks, -1
)
@ -2397,7 +2366,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
synced_gpus (`bool`, *optional*):
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
@ -2414,18 +2383,14 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchDecoderOnlyOutput`],
- [`~generation.SampleDecoderOnlyOutput`],
- [`~generation.BeamSearchDecoderOnlyOutput`],
- [`~generation.BeamSampleDecoderOnlyOutput`]
- [`~generation.GenerateDecoderOnlyOutput`],
- [`~generation.GenerateBeamDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
if generation_config is None:
@ -2440,37 +2405,23 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale
if model_kwargs.get("attention_mask", None) is None:
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)
if "encoder_hidden_states" not in model_kwargs:
@ -2484,46 +2435,28 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
decoder_start_token_id=generation_config._decoder_start_token_tensor,
bos_token_id=generation_config._bos_token_tensor,
device=inputs_tensor.device,
)
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
logger.warning(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
"to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation."
)
elif generation_config.max_new_tokens is not None:
if not has_default_max_length:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= generation_config.max_length:
logger.warning(
f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen Melody)
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids,
pad_token_id=generation_config.decoder_start_token_id,
pad_token_id=generation_config._decoder_start_token_tensor,
max_length=generation_config.max_length,
)
# stash the delay mask so that we don't have to recompute in each forward pass
@ -2544,7 +2477,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
@ -2599,7 +2532,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
# revert the pattern delay mask by filtering the pad token id
output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
batch_size, self.decoder.num_codebooks, -1
)

tests/generation/test_utils.py

@ -3196,6 +3196,40 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
def test_special_tokens_fall_back_to_model_default(self):
# PT-only test.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
torch_device
)
test_bos_id = 50
# Sanity-check: the model has a BOS token set, and the first generated token is a BOS token
gen_output = model.generate()
self.assertTrue(model.generation_config.bos_token_id is not None)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
# If we pass a generation config **with** a BOS token, `generate` will use it
generation_config = GenerationConfig(bos_token_id=test_bos_id)
gen_output = model.generate(generation_config=generation_config)
self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id == gen_output[0, 0])
self.assertTrue(test_bos_id == gen_output[0, 0])
# If we pass a generation config **without** a BOS token, `generate` will fetch the BOS token from
# `model.generation_config`
generation_config = GenerationConfig(bos_token_id=None)
gen_output = model.generate(generation_config=generation_config)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
self.assertFalse(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)
# Changing `model.generation_config` will affect fallback behavior
model.generation_config.bos_token_id = test_bos_id
gen_output = model.generate(generation_config=generation_config)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
self.assertTrue(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)
@require_torch
class TokenHealingTestCase(unittest.TestCase):