Mirror of https://github.com/huggingface/transformers.git
Commit df1c248a6d (parent 739a63166d)
src/transformers/generation/utils.py

@@ -2218,8 +2218,8 @@ class GenerationMixin:
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
-        streamer: Optional["BaseStreamer"] = None,
-        logits_warper: Optional[LogitsProcessorList] = None,
+        streamer: "BaseStreamer",
+        logits_warper: LogitsProcessorList,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -2818,34 +2818,6 @@ class GenerationMixin:
         else:
             return input_ids

-    def _greedy_search(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: LogitsProcessorList,
-        stopping_criteria: StoppingCriteriaList,
-        generation_config: GenerationConfig,
-        synced_gpus: bool,
-        streamer: Optional["BaseStreamer"],
-        **model_kwargs,
-    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
-        r"""
-        Deprecated. Use `._sample()` instead, passing the same arguments.
-        """
-
-        logger.warning_once(
-            "Calling `._greedy_search()` directly is deprecated and will be removed in v4.42. Use `._sample()` "
-            "instead, passing the same arguments."
-        )
-        return self._sample(
-            input_ids=input_ids,
-            logits_processor=logits_processor,
-            stopping_criteria=stopping_criteria,
-            generation_config=generation_config,
-            synced_gpus=synced_gpus,
-            streamer=streamer,
-            **model_kwargs,
-        )
-
     def _sample(
         self,
         input_ids: torch.LongTensor,
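The deleted `_greedy_search` wrapper only forwarded to `_sample`, so code that overrode or called it can target `_sample` directly; for ordinary callers the public `generate()` API is the simpler route, since greedy decoding is just `do_sample=False` with a single beam. A minimal sketch, with the `gpt2` checkpoint, the prompt, and `max_new_tokens` chosen purely for illustration:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Greedy decoding through the public API; the private `_greedy_search` wrapper
# removed above no longer exists, and `_sample` covers greedy and sampled decoding.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
output_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))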
@@ -2854,7 +2826,7 @@ class GenerationMixin:
         generation_config: GenerationConfig,
         synced_gpus: bool,
         streamer: Optional["BaseStreamer"],
-        logits_warper: Optional[LogitsProcessorList] = None,
+        logits_warper: LogitsProcessorList,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -3053,7 +3025,6 @@ class GenerationMixin:
             past_key_values.reorder_cache(beam_idx)
         return past_key_values

-    # TODO (joao, v4.42): remove default for `logits_warper`
     def _beam_search(
         self,
         input_ids: torch.LongTensor,
@@ -3062,7 +3033,7 @@ class GenerationMixin:
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
-        logits_warper: Optional[LogitsProcessorList] = None,
+        logits_warper: LogitsProcessorList,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
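With the `= None` default gone, anything that calls the private `_sample` or `_beam_search` methods directly has to pass `logits_warper` explicitly. A minimal sketch of the object that parameter expects, built from the stock warpers; the temperature, top-k, and top-p values are illustrative, and `generate()` keeps assembling an equivalent list internally from the `GenerationConfig`:

from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# An explicit warper list of the type the now-required `logits_warper` argument takes.
logits_warper = LogitsProcessorList(
    [
        TemperatureLogitsWarper(0.7),   # soften or sharpen the distribution
        TopKLogitsWarper(top_k=50),     # keep only the 50 most likely tokens
        TopPLogitsWarper(top_p=0.95),   # nucleus filtering
    ]
)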
@@ -3342,36 +3313,6 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]

-    def _beam_sample(
-        self,
-        input_ids: torch.LongTensor,
-        beam_scorer: BeamScorer,
-        logits_processor: LogitsProcessorList,
-        stopping_criteria: StoppingCriteriaList,
-        logits_warper: LogitsProcessorList,
-        generation_config: GenerationConfig,
-        synced_gpus: bool,
-        **model_kwargs,
-    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
-        r"""
-        Deprecated. Use `._beam_search()` instead, passing the same arguments.
-        """
-
-        logger.warning_once(
-            "Calling `._beam_sample()` directly is deprecated and will be removed in v4.42. Use `._beam_search()` "
-            "instead, passing the same arguments."
-        )
-        return self._beam_search(
-            input_ids=input_ids,
-            beam_scorer=beam_scorer,
-            logits_processor=logits_processor,
-            stopping_criteria=stopping_criteria,
-            logits_warper=logits_warper,
-            generation_config=generation_config,
-            synced_gpus=synced_gpus,
-            **model_kwargs,
-        )
-
     def _group_beam_search(
         self,
         input_ids: torch.LongTensor,
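The deleted `_beam_sample` wrapper likewise only forwarded to `_beam_search`, which handles both plain beam search and beam sampling. A minimal sketch of keeping beam sampling through the public API; the checkpoint, beam count, temperature, and prompt are illustrative only:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# `do_sample=True` combined with `num_beams > 1` selects beam sampling; per the diff
# above it now runs through the unified `_beam_search` path instead of `_beam_sample`.
output_ids = model.generate(
    **inputs,
    do_sample=True,
    num_beams=4,
    temperature=0.7,
    max_new_tokens=20,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))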
src/transformers/models/llava/configuration_llava.py

@@ -13,8 +13,6 @@
 # limitations under the License.
 """Llava model configuration"""

-import warnings
-
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
 from ..auto import CONFIG_MAPPING
@@ -96,12 +94,6 @@ class LlavaConfig(PretrainedConfig):
                 f"Got: {vision_feature_select_strategy}"
             )

-        if "vocab_size" in kwargs:
-            warnings.warn(
-                "The `vocab_size` argument is deprecated and will be removed in v4.42, since it can be inferred from the `text_config`. Passing this argument has no effect",
-                FutureWarning,
-            )
-
         self.vision_feature_select_strategy = vision_feature_select_strategy
         self.vision_feature_layer = vision_feature_layer
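The deleted branch was only a deprecation shim: as its message said, a top-level `vocab_size` argument to `LlavaConfig` had no effect, because the value is inferred from the nested text config. A minimal sketch of where to read it, with the checkpoint name chosen only for illustration:

from transformers import LlavaConfig

# The vocabulary size lives on the nested language-model config, so there is no
# top-level `vocab_size` argument to pass to LlavaConfig.
config = LlavaConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(config.text_config.vocab_size)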