Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-30 17:52:35 +06:00
Generate: v4.38 removals and related updates (#29171)
This commit is contained in: parent 24d59c7969, commit ece1b62b93
@@ -40,6 +40,11 @@ else:
         "BeamSearchScorer",
         "ConstrainedBeamSearchScorer",
     ]
+    _import_structure["candidate_generator"] = [
+        "AssistedCandidateGenerator",
+        "CandidateGenerator",
+        "PromptLookupCandidateGenerator",
+    ]
     _import_structure["logits_process"] = [
         "AlternatingCodebooksLogitsProcessor",
         "ClassifierFreeGuidanceLogitsProcessor",
@@ -178,6 +183,7 @@ if TYPE_CHECKING:
     else:
         from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
         from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
+        from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
         from .logits_process import (
             AlternatingCodebooksLogitsProcessor,
             ClassifierFreeGuidanceLogitsProcessor,
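The two hunks above export the candidate-generator classes from the generation submodule, so they can be imported directly. A minimal sketch, assuming a transformers build that includes this commit:

from transformers.generation import (
    AssistedCandidateGenerator,
    CandidateGenerator,
    PromptLookupCandidateGenerator,
)

# CandidateGenerator is the abstract interface; AssistedCandidateGenerator wraps a smaller
# assistant model, and PromptLookupCandidateGenerator proposes candidates from n-grams
# already present in the prompt.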
@@ -99,7 +99,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
         # Make sure all data at the same device as assistant model
         device = assistant_model.device
         input_ids = input_ids.to(device)
-        inputs_tensor = inputs_tensor.to(device)
+        if inputs_tensor is not None:
+            inputs_tensor = inputs_tensor.to(device)

         # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
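The change above guards the auxiliary input tensor before moving it to the assistant's device, presumably because callers can pass `inputs_tensor=None` (decoder-only models have no separate encoder input) and calling `.to(device)` on `None` would raise. A hypothetical, minimal illustration of the pattern:

import torch

def move_to_assistant_device(input_ids, inputs_tensor, device):
    # Always move the prompt ids; only move the auxiliary tensor when it exists.
    input_ids = input_ids.to(device)
    if inputs_tensor is not None:
        inputs_tensor = inputs_tensor.to(device)
    return input_ids, inputs_tensor

# Example: the None case passes through untouched.
ids, aux = move_to_assistant_device(torch.tensor([[1, 2, 3]]), None, torch.device("cpu"))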
@@ -4319,7 +4319,6 @@ class GenerationMixin:
     def assisted_decoding(
         self,
         input_ids: torch.LongTensor,
-        assistant_model: Optional["PreTrainedModel"] = None,
         candidate_generator: Optional["CandidateGenerator"] = None,
         do_sample: bool = False,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -4355,12 +4354,7 @@
                 The sequence used as a prompt for the generation.
             candidate_generator (`CandidateGenerator`, *optional*):
                 A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
-                more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function.
-            assistant_model (`PreTrainedModel`, *optional*):
-                An assistant model that can be used to accelerate generation. The assistant model must have the exact
-                same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
-                is much faster than running generation with the model you're calling generate from. As such, the
-                assistant model should be much smaller.
+                more information, the documentation of [`CandidateGenerator`] should be read.
             do_sample (`bool`, *optional*, defaults to `False`):
                 Whether or not to use sampling ; use greedy decoding otherwise.
             logits_processor (`LogitsProcessorList`, *optional*):
@@ -4417,6 +4411,7 @@
         ...     StoppingCriteriaList,
         ...     MaxLengthCriteria,
         ... )
+        >>> from transformers.generation import AssistedCandidateGenerator

         >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
         >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
@@ -4432,33 +4427,22 @@
         ...     ]
         ... )
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
+        >>> candidate_generator = AssistedCandidateGenerator(
+        ...     input_ids=input_ids,
+        ...     assistant_model=assistant_model,
+        ...     generation_config=model.generation_config,
+        ...     logits_processor=logits_processor,
+        ...     model_kwargs={},
+        ... )
         >>> outputs = model.assisted_decoding(
         ...     input_ids,
-        ...     assistant_model=assistant_model,
+        ...     candidate_generator=candidate_generator,
         ...     logits_processor=logits_processor,
         ...     stopping_criteria=stopping_criteria,
         ... )
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
         ```"""
-        # handling deprecated arguments
-        if (assistant_model is None) == (candidate_generator is None):
-            raise ValueError("One (and only one) of `assistant_model` and `candidate_generator` should be defined.")
-
-        if assistant_model is not None:
-            candidate_generator = AssistedCandidateGenerator(
-                input_ids=input_ids,
-                assistant_model=assistant_model,
-                logits_processor=logits_processor,
-                model_kwargs=model_kwargs,
-                eos_token_id=eos_token_id,
-            )
-            warnings.warn(
-                "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. "
-                "Pass the `candidate_generator` argument instead.",
-                FutureWarning,
-            )
-
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
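With the deprecated `assistant_model` argument removed from `assisted_decoding`, callers either build a candidate generator explicitly, as in the updated doctest above, or use the higher-level `generate()` API, which accepts `assistant_model` directly and constructs the candidate generator internally. A sketch of the latter route; the checkpoints here are only illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2-xl")    # main model
assistant = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")   # smaller assistant

inputs = tokenizer("It might be possible to", return_tensors="pt")
# generate() runs assisted (speculative) decoding when an assistant model is supplied.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))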
@@ -129,8 +129,8 @@ class OPTAttention(nn.Module):
             val = None
             if fn_arg_name in kwargs:
                 logging.warning(
-                    "Passing in {} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
-                    " Please set it in the config instead"
+                    "Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from "
+                    "v4.39. Please set it in the config instead"
                 )
                 val = kwargs.pop(fn_arg_name)
             else:
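This hunk only rewrites the deprecation message in OPT's attention layer: the new text spells out the offending argument name and pushes the removal target from v4.38 to v4.39. A hypothetical sketch of the underlying pattern (pop a deprecated keyword argument, warn about it, and otherwise fall back to the config value); the helper name and signature are illustrative, not the library's:

import warnings

def handle_deprecated_argument(config, config_arg_name, fn_arg_name, kwargs):
    # Prefer an explicitly passed (deprecated) kwarg, but warn so callers can migrate.
    if fn_arg_name in kwargs:
        warnings.warn(
            f"Passing in `{fn_arg_name}` is deprecated and won't be supported from v4.39. "
            "Please set it in the config instead.",
            FutureWarning,
        )
        return kwargs.pop(fn_arg_name)
    return getattr(config, config_arg_name)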
@@ -120,7 +120,6 @@ from .import_utils import (
     is_essentia_available,
     is_faiss_available,
     is_flash_attn_2_available,
-    is_flash_attn_available,
     is_flash_attn_greater_or_equal_2_10,
     is_flax_available,
     is_fsdp_available,
@@ -665,14 +665,6 @@ def is_flash_attn_greater_or_equal_2_10():
     return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")


-def is_flash_attn_available():
-    logger.warning(
-        "Using `is_flash_attn_available` is deprecated and will be removed in v4.38. "
-        "Please use `is_flash_attn_2_available` instead."
-    )
-    return is_flash_attn_2_available()
-
-
 def is_torchdistx_available():
     return _torchdistx_available
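With `is_flash_attn_available` removed, availability checks go through `is_flash_attn_2_available` (and `is_flash_attn_greater_or_equal_2_10` for version gating), both of which remain exported from `transformers.utils`. A minimal sketch of how a caller might pick an attention backend with the surviving helper:

from transformers.utils import is_flash_attn_2_available

# Request FlashAttention 2 when it is installed, otherwise fall back to SDPA;
# the chosen string is passed as `attn_implementation` to from_pretrained.
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"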