mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Handle padding warning in generation when using inputs_embeds
(#23131)
* Handle padding warning in generation when using `inputs_embeds` * Simpler condition * Black formatter * Changed warning logic
This commit is contained in:
parent
65d7b21b77
commit
291c5e9b25
@ -1307,8 +1307,11 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# decoder-only models should use left-padding for generation
|
# decoder-only models should use left-padding for generation
|
||||||
if not self.config.is_encoder_decoder:
|
if not self.config.is_encoder_decoder:
|
||||||
|
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
|
||||||
|
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
|
||||||
if (
|
if (
|
||||||
generation_config.pad_token_id is not None
|
generation_config.pad_token_id is not None
|
||||||
|
and len(inputs_tensor.shape) == 2
|
||||||
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
|
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
Loading…
Reference in New Issue
Block a user