Generate: v4.38 removals and related updates (#29171)

This commit is contained in:
Joao Gante 2024-02-26 13:36:12 +00:00 committed by GitHub
parent 24d59c7969
commit ece1b62b93
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 20 additions and 38 deletions

View File

@ -40,6 +40,11 @@ else:
"BeamSearchScorer",
"ConstrainedBeamSearchScorer",
]
_import_structure["candidate_generator"] = [
"AssistedCandidateGenerator",
"CandidateGenerator",
"PromptLookupCandidateGenerator",
]
_import_structure["logits_process"] = [
"AlternatingCodebooksLogitsProcessor",
"ClassifierFreeGuidanceLogitsProcessor",
@ -178,6 +183,7 @@ if TYPE_CHECKING:
else:
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
from .logits_process import (
AlternatingCodebooksLogitsProcessor,
ClassifierFreeGuidanceLogitsProcessor,

View File

@ -99,7 +99,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
# Make sure all data at the same device as assistant model
device = assistant_model.device
input_ids = input_ids.to(device)
inputs_tensor = inputs_tensor.to(device)
if inputs_tensor is not None:
inputs_tensor = inputs_tensor.to(device)
# Prepare the assistant and the starting number of candidate tokens
self.assistant_model = assistant_model

View File

@ -4319,7 +4319,6 @@ class GenerationMixin:
def assisted_decoding(
self,
input_ids: torch.LongTensor,
assistant_model: Optional["PreTrainedModel"] = None,
candidate_generator: Optional["CandidateGenerator"] = None,
do_sample: bool = False,
logits_processor: Optional[LogitsProcessorList] = None,
@ -4355,12 +4354,7 @@ class GenerationMixin:
The sequence used as a prompt for the generation.
candidate_generator (`CandidateGenerator`, *optional*):
A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function.
assistant_model (`PreTrainedModel`, *optional*):
An assistant model that can be used to accelerate generation. The assistant model must have the exact
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
is much faster than running generation with the model you're calling generate from. As such, the
assistant model should be much smaller.
more information, the documentation of [`CandidateGenerator`] should be read.
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling ; use greedy decoding otherwise.
logits_processor (`LogitsProcessorList`, *optional*):
@ -4417,6 +4411,7 @@ class GenerationMixin:
... StoppingCriteriaList,
... MaxLengthCriteria,
... )
>>> from transformers.generation import AssistedCandidateGenerator
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
@ -4432,33 +4427,22 @@ class GenerationMixin:
... ]
... )
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> candidate_generator = AssistedCandidateGenerator(
... input_ids=input_ids,
... assistant_model=assistant_model,
... generation_config=model.generation_config,
... logits_processor=logits_processor,
... model_kwargs={},
... )
>>> outputs = model.assisted_decoding(
... input_ids,
... assistant_model=assistant_model,
... candidate_generator=candidate_generator,
... logits_processor=logits_processor,
... stopping_criteria=stopping_criteria,
... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# handling deprecated arguments
if (assistant_model is None) == (candidate_generator is None):
raise ValueError("One (and only one) of `assistant_model` and `candidate_generator` should be defined.")
if assistant_model is not None:
candidate_generator = AssistedCandidateGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
eos_token_id=eos_token_id,
)
warnings.warn(
"Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. "
"Pass the `candidate_generator` argument instead.",
FutureWarning,
)
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()

View File

@ -129,8 +129,8 @@ class OPTAttention(nn.Module):
val = None
if fn_arg_name in kwargs:
logging.warning(
"Passing in {} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
" Please set it in the config instead"
"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from "
"v4.39. Please set it in the config instead"
)
val = kwargs.pop(fn_arg_name)
else:

View File

@ -120,7 +120,6 @@ from .import_utils import (
is_essentia_available,
is_faiss_available,
is_flash_attn_2_available,
is_flash_attn_available,
is_flash_attn_greater_or_equal_2_10,
is_flax_available,
is_fsdp_available,

View File

@ -665,14 +665,6 @@ def is_flash_attn_greater_or_equal_2_10():
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
def is_flash_attn_available():
logger.warning(
"Using `is_flash_attn_available` is deprecated and will be removed in v4.38. "
"Please use `is_flash_attn_2_available` instead."
)
return is_flash_attn_2_available()
def is_torchdistx_available():
return _torchdistx_available