Mirror of https://github.com/huggingface/transformers.git
Generate: store special token tensors under a unique variable name (#31980)
* rename stuff
* english; this one shouldn't be changed
* add a _ to the new var names
* musicgen
* derp
This commit is contained in:
parent aa8f86a421
commit c38c55f4fb
@@ -754,12 +754,12 @@ class GenerationMixin:
warpers = LogitsProcessorList()

# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
- # better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
+ # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
if generation_config.num_beams > 1:
- if isinstance(generation_config.eos_token_id, list):
- min_tokens_to_keep = len(generation_config.eos_token_id) + 1
- elif isinstance(generation_config.eos_token_id, torch.Tensor):
- min_tokens_to_keep = generation_config.eos_token_id.shape[0] + 1
+ if isinstance(generation_config._eos_token_tensor, list):
+ min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+ elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+ min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
else:
min_tokens_to_keep = 2
else:
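A minimal standalone sketch of the rule in the hunk above (illustration only, not the library code): with beam methods, one more candidate than the number of EOS ids is kept so that at least one non-EOS continuation survives; the non-beam branch, cut off at the end of the hunk, is assumed here to keep a single token:

import torch

def min_tokens_to_keep_for_beams(eos_token_tensor, num_beams):
    # Mirrors the branch above: beam search keeps len(eos) + 1 candidates so that at
    # least one non-EOS token can still be explored; otherwise fall back to 2.
    if num_beams <= 1:
        return 1  # assumption: value of the branch truncated in the hunk
    if isinstance(eos_token_tensor, list):
        return len(eos_token_tensor) + 1
    if isinstance(eos_token_tensor, torch.Tensor):
        return eos_token_tensor.shape[0] + 1
    return 2

print(min_tokens_to_keep_for_beams(torch.tensor([2, 32000]), num_beams=4))  # 3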
@@ -863,31 +863,31 @@ class GenerationMixin:
processors.append(
NoBadWordsLogitsProcessor(
generation_config.bad_words_ids,
- generation_config.eos_token_id,
+ generation_config._eos_token_tensor,
)
)
if (
generation_config.min_length is not None
- and generation_config.eos_token_id is not None
+ and generation_config._eos_token_tensor is not None
and generation_config.min_length > 0
):
processors.append(
MinLengthLogitsProcessor(
generation_config.min_length,
- generation_config.eos_token_id,
+ generation_config._eos_token_tensor,
device=device,
)
)
if (
generation_config.min_new_tokens is not None
- and generation_config.eos_token_id is not None
+ and generation_config._eos_token_tensor is not None
and generation_config.min_new_tokens > 0
):
processors.append(
MinNewTokensLengthLogitsProcessor(
input_ids_seq_length,
generation_config.min_new_tokens,
- generation_config.eos_token_id,
+ generation_config._eos_token_tensor,
device=device,
)
)

@@ -918,7 +918,7 @@ class GenerationMixin:
processors.append(
ExponentialDecayLengthPenalty(
generation_config.exponential_decay_length_penalty,
- generation_config.eos_token_id,
+ generation_config._eos_token_tensor,
input_ids_seq_length,
)
)

@@ -997,8 +997,8 @@ class GenerationMixin:
"stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
)
criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
- if generation_config.eos_token_id is not None:
- criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
+ if generation_config._eos_token_tensor is not None:
+ criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
return criteria
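A rough sketch of the check an EOS stopping criterion performs on the tensor now stored in `generation_config._eos_token_tensor` (illustration only, not the EosTokenCriteria implementation):

import torch

def batch_hit_eos(last_tokens, eos_token_tensor):
    # For each sequence in the batch, has the most recent token matched any of the
    # (possibly several) EOS ids? `eos_token_tensor` is assumed to be a 1D long tensor.
    return torch.isin(last_tokens, eos_token_tensor)

last_tokens = torch.tensor([11, 2, 32000])   # last generated token of each sequence
eos_ids = torch.tensor([2, 32000])           # more than one EOS id is allowed
print(batch_hit_eos(last_tokens, eos_ids))   # tensor([False,  True,  True])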
@@ -1349,13 +1349,15 @@ class GenerationMixin:
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
"""
- Prepares the base generation config, then applies any generation configuration options from kwargs.
+ Prepares the base generation config, then applies any generation configuration options from kwargs. This
+ function handles retrocompatibility with respect to configuration files.
"""
# TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
# replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
# the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.

# priority: `generation_config` argument > `model.generation_config` (the default generation config)
using_model_generation_config = False
if generation_config is None:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# three conditions must be met

@@ -1378,6 +1380,7 @@ class GenerationMixin:
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self.generation_config = new_generation_config
+ using_model_generation_config = True
generation_config = self.generation_config

# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`

@@ -1395,6 +1398,16 @@ class GenerationMixin:
else:
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
+ # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
+ if not using_model_generation_config:
+ if generation_config.bos_token_id is None:
+ generation_config.bos_token_id = self.generation_config.bos_token_id
+ if generation_config.eos_token_id is None:
+ generation_config.eos_token_id = self.generation_config.eos_token_id
+ if generation_config.pad_token_id is None:
+ generation_config.pad_token_id = self.generation_config.pad_token_id
+ if generation_config.decoder_start_token_id is None:
+ generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id

return generation_config, model_kwargs
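The fallback added above is purely attribute-level; a minimal standalone sketch of the same rule, using plain objects instead of GenerationConfig for illustration only:

from types import SimpleNamespace

def fall_back_special_tokens(user_config, model_default_config):
    # Sketch of the added block: any special token left unset on a user-provided config
    # is copied over from the model's default generation config.
    for name in ("bos_token_id", "eos_token_id", "pad_token_id", "decoder_start_token_id"):
        if getattr(user_config, name) is None:
            setattr(user_config, name, getattr(model_default_config, name))
    return user_config

user = SimpleNamespace(bos_token_id=None, eos_token_id=2, pad_token_id=None, decoder_start_token_id=None)
default = SimpleNamespace(bos_token_id=1, eos_token_id=2, pad_token_id=0, decoder_start_token_id=None)
print(fall_back_special_tokens(user, default).bos_token_id)  # 1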
@@ -1493,52 +1506,46 @@ class GenerationMixin:
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
"""

- # Convert special tokens to tensors (if they exist either in kwargs or in self.config)
- def _tensor_or_none(token_kwargs, token_self, device=None):
- if device is None:
- device = self.device
- token = token_kwargs if token_kwargs is not None else token_self
+ # Convert special tokens to tensors
+ def _tensor_or_none(token, device=None):
if token is None:
return token
- elif isinstance(token, torch.Tensor):
- return token.to(device)
+ device = device if device is not None else self.device
+ if isinstance(token, torch.Tensor):
+ return token.to(device)
return torch.tensor(token, device=device, dtype=torch.long)

- bos_token_id = _tensor_or_none(
- generation_config.bos_token_id, self.generation_config.bos_token_id, device=device
- )
- eos_token_id = _tensor_or_none(
- generation_config.eos_token_id, self.generation_config.eos_token_id, device=device
- )
- pad_token_id = _tensor_or_none(
- generation_config.pad_token_id, self.generation_config.pad_token_id, device=device
- )
- decoder_start_token_id = _tensor_or_none(
- generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device
- )
+ bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
+ eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
+ pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
+ decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device)

# for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
if self.config.is_encoder_decoder:
- decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
+ decoder_start_token_tensor = (
+ decoder_start_token_tensor if decoder_start_token_tensor is not None else bos_token_tensor
+ )

# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
- if eos_token_id is not None and eos_token_id.ndim == 0:
- eos_token_id = eos_token_id.unsqueeze(0)
+ if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
+ eos_token_tensor = eos_token_tensor.unsqueeze(0)

# Set pad token if unset (and there are conditions to do so)
- if pad_token_id is None and eos_token_id is not None:
+ if pad_token_tensor is None and eos_token_tensor is not None:
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
- pad_token_id = eos_token_id[0]
- logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
+ pad_token_tensor = eos_token_tensor[0]
+ logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

# we can't infer attn mask if pad token is set to be eos token in model's generation config
- if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
+ if (
+ eos_token_tensor is not None
+ and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
+ ):
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once(
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token."

@@ -1547,21 +1554,26 @@ class GenerationMixin:
)

# Sanity checks/warnings
- if self.config.is_encoder_decoder and decoder_start_token_id is None:
+ if self.config.is_encoder_decoder and decoder_start_token_tensor is None:
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
- if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()):
+ if eos_token_tensor is not None and (
+ torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
+ ):
logger.warning(
- f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not "
+ f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation will not "
"stop until the maximum length is reached. Depending on other flags, it may even crash."
)

# Update generation config with the updated special tokens tensors
- generation_config.bos_token_id = bos_token_id
- generation_config.eos_token_id = eos_token_id
- generation_config.pad_token_id = pad_token_id
- generation_config.decoder_start_token_id = decoder_start_token_id
+ # NOTE: this must be written into a different attribute name than the one holding the original special tokens
+ # (in their non-tensor form), in order to enable end-to-end compilation. See
+ # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
+ generation_config._bos_token_tensor = bos_token_tensor
+ generation_config._eos_token_tensor = eos_token_tensor
+ generation_config._pad_token_tensor = pad_token_tensor
+ generation_config._decoder_start_token_tensor = decoder_start_token_tensor

@torch.no_grad()
def generate(
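A standalone sketch of what the rewritten helper does and why its results are stashed under new private names (simplified, not the library closure itself; the `_eos_token_tensor`-style attributes are what the processors and criteria updated earlier in this diff read):

import torch

def tensor_or_none(token, device="cpu"):
    # Simplified stand-in for the `_tensor_or_none` closure above.
    if token is None:
        return None
    if isinstance(token, torch.Tensor):
        return token.to(device)
    return torch.tensor(token, device=device, dtype=torch.long)

eos = tensor_or_none([2, 32000])   # several EOS ids become one 1D long tensor
pad = tensor_or_none(0)            # a plain id becomes a 0-dim tensor
if eos is not None and eos.ndim == 0:
    eos = eos.unsqueeze(0)         # EOS is always kept 1D, as in the hunk above
# The tensors go into `_bos/_eos/_pad/_decoder_start..._token_tensor` instead of being
# written back onto `*_token_id`, so the original ints on the config are left untouched;
# the NOTE in the hunk above ties this to enabling end-to-end compilation.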
@@ -1696,10 +1708,10 @@ class GenerationMixin:
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
if (
- generation_config.pad_token_id is not None
+ generation_config._pad_token_tensor is not None
and batch_size > 1
and len(inputs_tensor.shape) == 2
- and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+ and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "

@@ -1716,7 +1728,7 @@ class GenerationMixin:
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+ inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)

if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:

@@ -1731,7 +1743,7 @@ class GenerationMixin:
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
- decoder_start_token_id=generation_config.decoder_start_token_id,
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
device=inputs_tensor.device,
)
else:

@@ -2279,7 +2291,7 @@ class GenerationMixin:
raise ValueError("DoLa decoding is only available for decoder-only models.")
# init values
- pad_token_id = generation_config.pad_token_id
+ pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores

@@ -2486,7 +2498,7 @@ class GenerationMixin:
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
top_k = generation_config.top_k
penalty_alpha = generation_config.penalty_alpha
- pad_token_id = generation_config.pad_token_id
+ pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores

@@ -2877,7 +2889,7 @@ class GenerationMixin:
`model.config.is_encoder_decoder=True`.
"""
# init values
- pad_token_id = generation_config.pad_token_id
+ pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores

@@ -3084,8 +3096,8 @@ class GenerationMixin:
`model.config.is_encoder_decoder=True`.
"""
# init values
- pad_token_id = generation_config.pad_token_id
- eos_token_id = generation_config.eos_token_id
+ pad_token_id = generation_config._pad_token_tensor
+ eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores

@@ -3366,8 +3378,8 @@ class GenerationMixin:
`model.config.is_encoder_decoder=True`.
"""
# init values
- pad_token_id = generation_config.pad_token_id
- eos_token_id = generation_config.eos_token_id
+ pad_token_id = generation_config._pad_token_tensor
+ eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores

@@ -3658,8 +3670,8 @@ class GenerationMixin:
`model.config.is_encoder_decoder=True`.
"""
# init values
- pad_token_id = generation_config.pad_token_id
- eos_token_id = generation_config.eos_token_id
+ pad_token_id = generation_config._pad_token_tensor
+ eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@@ -1539,75 +1539,43 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

- if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
- if model_kwargs.get("attention_mask", None) is None:
- logger.warning(
- "The attention mask and the pad token id were not set. As a consequence, you may observe "
- "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
- )
- eos_token_id = generation_config.eos_token_id
- if isinstance(eos_token_id, list):
- eos_token_id = eos_token_id[0]
- logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
- generation_config.pad_token_id = eos_token_id
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

- # 3. Define model inputs
- # inputs_tensor has to be defined
- # model_input_name is defined if model-specific keyword input is passed
- # otherwise model_input_name is None
- # all model-specific keyword inputs are removed from `model_kwargs`
+ # 3. Define model inputs`
input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = input_ids.shape[0] // self.num_codebooks
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)

# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale

- requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- input_ids, generation_config.pad_token_id, generation_config.eos_token_id
+ input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)

# 5. Prepare `max_length` depending on other stopping criteria.
- input_ids_seq_length = input_ids.shape[-1]
+ input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
- if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
- logger.warning(
- f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
- "to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation."
- )
- elif generation_config.max_new_tokens is not None:
- if not has_default_max_length:
- logger.warning(
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
- "Please refer to the documentation for more information. "
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
- )
- generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
- if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
- raise ValueError(
- f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
- f" the maximum length ({generation_config.max_length})"
- )
- if input_ids_seq_length >= generation_config.max_length:
- logger.warning(
- f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
- " increasing `max_new_tokens`."
- )
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=input_ids,
+ input_ids_length=input_ids_length,
+ )

# 6. Prepare `input_ids` which will be used for auto-regressive generation
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
- pad_token_id=generation_config.decoder_start_token_id,
+ pad_token_id=generation_config._decoder_start_token_tensor,
max_length=generation_config.max_length,
)
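The removed block above resolved `max_length` by hand; the same precedence rule, grounded in the removed lines (`_prepare_generated_length` is assumed to cover it now), as a tiny standalone sketch:

def resolve_max_length(max_length, max_new_tokens, input_ids_length):
    # Mirrors the removed logic: `max_new_tokens` takes precedence over `max_length`,
    # and the budget is counted from the end of the prompt.
    if max_new_tokens is not None:
        return max_new_tokens + input_ids_length
    return max_length

assert resolve_max_length(max_length=20, max_new_tokens=256, input_ids_length=4) == 260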
@@ -1628,7 +1596,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
- input_ids_seq_length=input_ids_seq_length,
+ input_ids_seq_length=input_ids_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,

@@ -1682,7 +1650,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

# revert the pattern delay mask by filtering the pad token id
- output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
+ output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
batch_size, self.num_codebooks, -1
)

@@ -2590,39 +2558,23 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

- if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
- if model_kwargs.get("attention_mask", None) is None:
- logger.warning(
- "The attention mask and the pad token id were not set. As a consequence, you may observe "
- "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
- )
- eos_token_id = generation_config.eos_token_id
- if isinstance(eos_token_id, list):
- eos_token_id = eos_token_id[0]
- logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
- generation_config.pad_token_id = eos_token_id
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

# 3. Define model inputs
- # inputs_tensor has to be defined
- # model_input_name is defined if model-specific keyword input is passed
- # otherwise model_input_name is None
- # all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)

# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale

- requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+ inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)

if "encoder_outputs" not in model_kwargs:

@@ -2642,45 +2594,28 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
- decoder_start_token_id=generation_config.decoder_start_token_id,
- bos_token_id=generation_config.bos_token_id,
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
+ bos_token_id=generation_config._bos_token_tensor,
device=inputs_tensor.device,
)

# 6. Prepare `max_length` depending on other stopping criteria.
- input_ids_seq_length = input_ids.shape[-1]
+ input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
- if has_default_max_length and generation_config.max_new_tokens is None:
- logger.warning(
- f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
- "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation."
- )
- elif generation_config.max_new_tokens is not None:
- if not has_default_max_length:
- logger.warning(
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
- "Please refer to the documentation for more information. "
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
- )
- generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
- if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
- raise ValueError(
- f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
- f" the maximum length ({generation_config.max_length})"
- )
- if input_ids_seq_length >= generation_config.max_length:
- logger.warning(
- f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
- " increasing `max_new_tokens`."
- )
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=inputs_tensor,
+ input_ids_length=input_ids_length,
+ )

# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids,
- pad_token_id=generation_config.decoder_start_token_id,
+ pad_token_id=generation_config._decoder_start_token_tensor,
max_length=generation_config.max_length,
)
# stash the delay mask so that we don't have to recompute in each forward pass

@@ -2701,7 +2636,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
- input_ids_seq_length=input_ids_seq_length,
+ input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,

@@ -2756,7 +2691,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

# revert the pattern delay mask by filtering the pad token id
- output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
+ output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
batch_size, self.decoder.num_codebooks, -1
)
@@ -1375,6 +1375,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
return input_ids

@torch.no_grad()
+ # Ignore copy
def generate(
self,
inputs: Optional[torch.Tensor] = None,

@@ -1460,75 +1461,43 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

- if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
- if model_kwargs.get("attention_mask", None) is None:
- logger.warning(
- "The attention mask and the pad token id were not set. As a consequence, you may observe "
- "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
- )
- eos_token_id = generation_config.eos_token_id
- if isinstance(eos_token_id, list):
- eos_token_id = eos_token_id[0]
- logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
- generation_config.pad_token_id = eos_token_id
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

- # 3. Define model inputs
- # inputs_tensor has to be defined
- # model_input_name is defined if model-specific keyword input is passed
- # otherwise model_input_name is None
- # all model-specific keyword inputs are removed from `model_kwargs`
+ # 3. Define model inputs`
input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = input_ids.shape[0] // self.num_codebooks
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)

# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale

# Ignore copy
- if model_kwargs.get("attention_mask", None) is None:
+ if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- input_ids, generation_config.pad_token_id, generation_config.eos_token_id
+ input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)

# 5. Prepare `max_length` depending on other stopping criteria.
- input_ids_seq_length = input_ids.shape[-1]
+ input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
- if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
- logger.warning(
- f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
- "to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation."
- )
- elif generation_config.max_new_tokens is not None:
- if not has_default_max_length:
- logger.warning(
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
- "Please refer to the documentation for more information. "
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
- )
- generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
- if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
- raise ValueError(
- f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
- f" the maximum length ({generation_config.max_length})"
- )
- if input_ids_seq_length >= generation_config.max_length:
- logger.warning(
- f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
- " increasing `max_new_tokens`."
- )
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=input_ids,
+ input_ids_length=input_ids_length,
+ )

# 6. Prepare `input_ids` which will be used for auto-regressive generation
- # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
+ # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen)
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
- pad_token_id=generation_config.decoder_start_token_id,
+ pad_token_id=generation_config._decoder_start_token_tensor,
max_length=generation_config.max_length,
)
@@ -1549,7 +1518,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
- input_ids_seq_length=input_ids_seq_length,
+ input_ids_seq_length=input_ids_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,

@@ -1603,7 +1572,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

# revert the pattern delay mask by filtering the pad token id
- output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
+ output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
batch_size, self.num_codebooks, -1
)

@@ -2397,7 +2366,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
- synced_gpus (`bool`, *optional*):
+ synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed

@@ -2414,18 +2383,14 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:

- - [`~generation.GreedySearchDecoderOnlyOutput`],
- - [`~generation.SampleDecoderOnlyOutput`],
- - [`~generation.BeamSearchDecoderOnlyOutput`],
- - [`~generation.BeamSampleDecoderOnlyOutput`]
+ - [`~generation.GenerateDecoderOnlyOutput`],
+ - [`~generation.GenerateBeamDecoderOnlyOutput`]

If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:

- - [`~generation.GreedySearchEncoderDecoderOutput`],
- - [`~generation.SampleEncoderDecoderOutput`],
- - [`~generation.BeamSearchEncoderDecoderOutput`],
- - [`~generation.BeamSampleEncoderDecoderOutput`]
+ - [`~generation.GenerateEncoderDecoderOutput`],
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
if generation_config is None:
@@ -2440,37 +2405,23 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

- if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
- if model_kwargs.get("attention_mask", None) is None:
- logger.warning(
- "The attention mask and the pad token id were not set. As a consequence, you may observe "
- "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
- )
- eos_token_id = generation_config.eos_token_id
- if isinstance(eos_token_id, list):
- eos_token_id = eos_token_id[0]
- logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
- generation_config.pad_token_id = eos_token_id
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

# 3. Define model inputs
- # inputs_tensor has to be defined
- # model_input_name is defined if model-specific keyword input is passed
- # otherwise model_input_name is None
- # all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)

# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale

- if model_kwargs.get("attention_mask", None) is None:
+ if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+ inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)

if "encoder_hidden_states" not in model_kwargs:

@@ -2484,46 +2435,28 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
- decoder_start_token_id=generation_config.decoder_start_token_id,
- bos_token_id=generation_config.bos_token_id,
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
+ bos_token_id=generation_config._bos_token_tensor,
device=inputs_tensor.device,
)

# 6. Prepare `max_length` depending on other stopping criteria.
- input_ids_seq_length = input_ids.shape[-1]
+ input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
- if has_default_max_length and generation_config.max_new_tokens is None:
- logger.warning(
- f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
- "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation."
- )
- elif generation_config.max_new_tokens is not None:
- if not has_default_max_length:
- logger.warning(
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
- "Please refer to the documentation for more information. "
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
- )
- generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=inputs_tensor,
+ input_ids_length=input_ids_length,
+ )

- if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
- raise ValueError(
- f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
- f" the maximum length ({generation_config.max_length})"
- )
- if input_ids_seq_length >= generation_config.max_length:
- logger.warning(
- f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
- " increasing `max_new_tokens`."
- )

- # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen Melody)
+ # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids,
- pad_token_id=generation_config.decoder_start_token_id,
+ pad_token_id=generation_config._decoder_start_token_tensor,
max_length=generation_config.max_length,
)
# stash the delay mask so that we don't have to recompute in each forward pass

@@ -2544,7 +2477,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
- input_ids_seq_length=input_ids_seq_length,
+ input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,

@@ -2599,7 +2532,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

# revert the pattern delay mask by filtering the pad token id
- output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
+ output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
batch_size, self.decoder.num_codebooks, -1
)
@@ -3196,6 +3196,40 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)

+ def test_special_tokens_fall_back_to_model_default(self):
+ # PT-only test: TF doesn't support assisted decoding yet.
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+ torch_device
+ )
+ test_bos_id = 50
+
+ # Sanity-check: the model has a BOS token set, and the first generated token is a BOS token
+ gen_output = model.generate()
+ self.assertTrue(model.generation_config.bos_token_id is not None)
+ self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
+
+ # If we pass a generation config **with** a BOS token, `generate` will use it
+ generation_config = GenerationConfig(bos_token_id=test_bos_id)
+ gen_output = model.generate(generation_config=generation_config)
+ self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0])
+ self.assertTrue(generation_config.bos_token_id == gen_output[0, 0])
+ self.assertTrue(test_bos_id == gen_output[0, 0])
+
+ # If we pass a generation config **without** a BOS token, `generate` will fetch the BOS token from
+ # `model.generation_config`
+ generation_config = GenerationConfig(bos_token_id=None)
+ gen_output = model.generate(generation_config=generation_config)
+ self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
+ self.assertFalse(test_bos_id == gen_output[0, 0])
+ self.assertTrue(generation_config.bos_token_id is None)
+
+ # Changing `model.generation_config` will affect fallback behavior
+ model.generation_config.bos_token_id = test_bos_id
+ gen_output = model.generate(generation_config=generation_config)
+ self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
+ self.assertTrue(test_bos_id == gen_output[0, 0])
+ self.assertTrue(generation_config.bos_token_id is None)

@require_torch
class TokenHealingTestCase(unittest.TestCase):
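For context, a hedged user-level sketch of the behavior the new test pins down (checkpoint name taken from the test itself; the asserts mirror its expectations):

from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")

# A config without a BOS token falls back to the model's default generation config ...
out = model.generate(generation_config=GenerationConfig(bos_token_id=None))
assert out[0, 0] == model.generation_config.bos_token_id

# ... while an explicit BOS id on the passed config takes precedence.
out = model.generate(generation_config=GenerationConfig(bos_token_id=50))
assert out[0, 0] == 50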