mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Improve greedy search memory usage (#32895)
Do not call torch.repeat_interleave if expand_size is 1
This commit is contained in:
parent
bf97d4aa6d
commit
99d67f1a09
@ -617,6 +617,10 @@ class GenerationMixin:
|
||||
**model_kwargs,
|
||||
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
|
||||
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
|
||||
# Do not call torch.repeat_interleave if expand_size is 1 because it clones
|
||||
# the input tensor and thus requires more memory although no change is applied
|
||||
if expand_size == 1:
|
||||
return input_ids, model_kwargs
|
||||
|
||||
def _expand_dict_for_generation(dict_to_expand):
|
||||
for key in dict_to_expand:
|
||||
|
Loading…
Reference in New Issue
Block a user