diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index d1e81cffca6..e45f546cdc2 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -40,6 +40,11 @@ else:
         "BeamSearchScorer",
         "ConstrainedBeamSearchScorer",
     ]
+    _import_structure["candidate_generator"] = [
+        "AssistedCandidateGenerator",
+        "CandidateGenerator",
+        "PromptLookupCandidateGenerator",
+    ]
     _import_structure["logits_process"] = [
         "AlternatingCodebooksLogitsProcessor",
         "ClassifierFreeGuidanceLogitsProcessor",
@@ -178,6 +183,7 @@ if TYPE_CHECKING:
 else:
     from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
     from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
+    from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
     from .logits_process import (
         AlternatingCodebooksLogitsProcessor,
         ClassifierFreeGuidanceLogitsProcessor,
diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 616afa19317..4b8fa144f04 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -99,7 +99,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
         # Make sure all data at the same device as assistant model
         device = assistant_model.device
         input_ids = input_ids.to(device)
-        inputs_tensor = inputs_tensor.to(device)
+        if inputs_tensor is not None:
+            inputs_tensor = inputs_tensor.to(device)
 
         # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
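The two hunks above make the candidate generators part of the public `transformers.generation` namespace and let `AssistedCandidateGenerator` be built without an `inputs_tensor`. A minimal sketch of what that enables, mirroring the keyword arguments used in the docstring example later in this patch; the checkpoints and prompt are illustrative, not part of the diff:

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)
from transformers.generation import AssistedCandidateGenerator  # public as of this change

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
assistant_model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")  # illustrative assistant

input_ids = tokenizer("It might be possible to", return_tensors="pt").input_ids
logits_processor = LogitsProcessorList(
    [MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id)]
)

# With the `inputs_tensor` guard added above, the tensor can simply be omitted:
# it defaults to None and is only moved to the assistant's device when provided.
candidate_generator = AssistedCandidateGenerator(
    input_ids=input_ids,
    assistant_model=assistant_model,
    generation_config=model.generation_config,
    logits_processor=logits_processor,
    model_kwargs={},
)
```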
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index d337e559344..c7e03123a9e 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4319,7 +4319,6 @@ class GenerationMixin:
     def assisted_decoding(
         self,
         input_ids: torch.LongTensor,
-        assistant_model: Optional["PreTrainedModel"] = None,
         candidate_generator: Optional["CandidateGenerator"] = None,
         do_sample: bool = False,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -4355,12 +4354,7 @@ class GenerationMixin:
                 The sequence used as a prompt for the generation.
             candidate_generator (`CandidateGenerator`, *optional*):
                 A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
-                more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function.
-            assistant_model (`PreTrainedModel`, *optional*):
-                An assistant model that can be used to accelerate generation. The assistant model must have the exact
-                same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
-                is much faster than running generation with the model you're calling generate from. As such, the
-                assistant model should be much smaller.
+                more information, the documentation of [`CandidateGenerator`] should be read.
             do_sample (`bool`, *optional*, defaults to `False`):
                 Whether or not to use sampling ; use greedy decoding otherwise.
             logits_processor (`LogitsProcessorList`, *optional*):
@@ -4417,6 +4411,7 @@ class GenerationMixin:
         ...     StoppingCriteriaList,
         ...     MaxLengthCriteria,
         ... )
+        >>> from transformers.generation import AssistedCandidateGenerator
 
         >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
         >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
@@ -4432,33 +4427,22 @@ class GenerationMixin:
         ...     ]
         ... )
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
+        >>> candidate_generator = AssistedCandidateGenerator(
+        ...     input_ids=input_ids,
+        ...     assistant_model=assistant_model,
+        ...     generation_config=model.generation_config,
+        ...     logits_processor=logits_processor,
+        ...     model_kwargs={},
+        ... )
         >>> outputs = model.assisted_decoding(
         ...     input_ids,
-        ...     assistant_model=assistant_model,
+        ...     candidate_generator=candidate_generator,
         ...     logits_processor=logits_processor,
         ...     stopping_criteria=stopping_criteria,
         ... )
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
         ```"""
-        # handling deprecated arguments
-        if (assistant_model is None) == (candidate_generator is None):
-            raise ValueError("One (and only one) of `assistant_model` and `candidate_generator` should be defined.")
-
-        if assistant_model is not None:
-            candidate_generator = AssistedCandidateGenerator(
-                input_ids=input_ids,
-                assistant_model=assistant_model,
-                logits_processor=logits_processor,
-                model_kwargs=model_kwargs,
-                eos_token_id=eos_token_id,
-            )
-            warnings.warn(
-                "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. "
-                "Pass the `candidate_generator` argument instead.",
-                FutureWarning,
-            )
-
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index d6f0924f427..7c66f5c255e 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -129,8 +129,8 @@ class OPTAttention(nn.Module):
             val = None
             if fn_arg_name in kwargs:
                 logging.warning(
-                    "Passing in {} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
-                    " Please set it in the config instead"
+                    "Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from "
+                    "v4.39. Please set it in the config instead"
                 )
                 val = kwargs.pop(fn_arg_name)
             else:
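The OPT hunk only rewords the deprecation message, but the advice it gives ("set it in the config instead") can be illustrated briefly. A hedged sketch, assuming the attention module already accepts a config plus `is_decoder` the way the deprecation shim implies; the concrete config values are illustrative:

```python
from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTAttention

# Deprecated style (triggers the warning above): passing embed_dim/num_heads/... as kwargs.
# Config-driven style the message asks for: put those values on OPTConfig instead.
config = OPTConfig(hidden_size=768, num_attention_heads=12, attention_dropout=0.0, enable_bias=True)
attn = OPTAttention(config=config, is_decoder=True)
```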
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 3a3c65a3b7d..154077924be 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -120,7 +120,6 @@ from .import_utils import (
     is_essentia_available,
     is_faiss_available,
     is_flash_attn_2_available,
-    is_flash_attn_available,
     is_flash_attn_greater_or_equal_2_10,
     is_flax_available,
     is_fsdp_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 8cf6c1a14f3..095af536621 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -665,14 +665,6 @@ def is_flash_attn_greater_or_equal_2_10():
     return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
 
 
-def is_flash_attn_available():
-    logger.warning(
-        "Using `is_flash_attn_available` is deprecated and will be removed in v4.38. "
-        "Please use `is_flash_attn_2_available` instead."
-    )
-    return is_flash_attn_2_available()
-
-
 def is_torchdistx_available():
     return _torchdistx_available
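Because `is_flash_attn_available` is removed from both the `transformers.utils` re-export and `import_utils.py`, downstream code has to query flash-attention support through `is_flash_attn_2_available`. A hedged caller-side sketch; the try/except fallback is only an assumption for code that must also run on releases predating the rename, and the `attn_implementation` strings are the usual `from_pretrained` values rather than anything this patch touches:

```python
try:
    from transformers.utils import is_flash_attn_2_available
except ImportError:
    # Assumption: fallback only needed on old transformers releases that still
    # ship the deprecated helper instead of the *_2_* one.
    from transformers.utils import is_flash_attn_available as is_flash_attn_2_available

# Pick an attention backend based on what is installed.
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
print(f"loading models with attn_implementation={attn_implementation!r}")
```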