Improve greedy search memory usage (#32895)

Do not call torch.repeat_interleave if expand_size is 1
regisss 2024-08-22 16:37:44 +02:00 committed by GitHub
parent bf97d4aa6d
commit 99d67f1a09

@@ -617,6 +617,10 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
         """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+        # Do not call torch.repeat_interleave if expand_size is 1 because it clones
+        # the input tensor and thus requires more memory although no change is applied
+        if expand_size == 1:
+            return input_ids, model_kwargs
 
         def _expand_dict_for_generation(dict_to_expand):
             for key in dict_to_expand:
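
For context, a minimal sketch (not part of the patch) of the behavior the new guard avoids: torch.repeat_interleave always materializes a fresh tensor, even when the repeat count is 1, so calling it with expand_size == 1 duplicates input_ids in memory without changing its contents. The tensor shapes and names below are illustrative assumptions.

import torch

# A batch of token ids, as produced during greedy search (shape is arbitrary here).
input_ids = torch.randint(0, 32000, (8, 1024), dtype=torch.long)

# repeat_interleave with repeats=1 is a content-wise no-op but still allocates a copy.
expanded = input_ids.repeat_interleave(1, dim=0)
print(torch.equal(expanded, input_ids))             # True: identical values
print(expanded.data_ptr() == input_ids.data_ptr())  # False: separate allocation

# With the early return, expand_size == 1 hands back the original tensors untouched,
# so no extra copy of input_ids (or of the tensors in model_kwargs) is created.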