From c3c39f7e84762685b3f1bcf8d3eec15fc6f4c13f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 16 Jun 2021 09:43:54 +0100 Subject: [PATCH] [Flax] Add Beam Search (#12131) * fix_torch_device_generate_test * remove @ * push new logit processors * add processors * save first working version * save intermediate * finish * make style * make fix-copies * finish * Update tests/test_modeling_flax_bart.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Suraj Patil Co-authored-by: Patrick von Platen Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Suraj Patil --- docs/source/internal/generation_utils.rst | 9 + src/transformers/__init__.py | 6 + .../generation_flax_logits_process.py | 113 ++++- src/transformers/generation_flax_utils.py | 432 ++++++++++++++++-- src/transformers/utils/dummy_flax_objects.py | 27 ++ tests/test_generation_flax_logits_process.py | 160 ++++++- tests/test_generation_flax_utils.py | 77 ++++ tests/test_modeling_flax_bart.py | 60 +++ 8 files changed, 840 insertions(+), 44 deletions(-) diff --git a/docs/source/internal/generation_utils.rst b/docs/source/internal/generation_utils.rst index 04543a48be1..00dff63d983 100644 --- a/docs/source/internal/generation_utils.rst +++ b/docs/source/internal/generation_utils.rst @@ -186,6 +186,15 @@ generation. .. autoclass:: transformers.FlaxTopKLogitsWarper :members: __call__ +.. autoclass:: transformers.FlaxForcedBOSTokenLogitsProcessor + :members: __call__ + +.. autoclass:: transformers.FlaxForcedEOSTokenLogitsProcessor + :members: __call__ + +.. autoclass:: transformers.FlaxMinLengthLogitsProcessor + :members: __call__ + StoppingCriteria ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d46011f34c4..f244d167535 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1486,9 +1486,12 @@ else: # FLAX-backed objects if is_flax_available(): _import_structure["generation_flax_logits_process"] = [ + "FlaxForcedBOSTokenLogitsProcessor", + "FlaxForcedEOSTokenLogitsProcessor", "FlaxLogitsProcessor", "FlaxLogitsProcessorList", "FlaxLogitsWarper", + "FlaxMinLengthLogitsProcessor", "FlaxTemperatureLogitsWarper", "FlaxTopKLogitsWarper", "FlaxTopPLogitsWarper", @@ -2814,9 +2817,12 @@ if TYPE_CHECKING: if is_flax_available(): from .generation_flax_logits_process import ( + FlaxForcedBOSTokenLogitsProcessor, + FlaxForcedEOSTokenLogitsProcessor, FlaxLogitsProcessor, FlaxLogitsProcessorList, FlaxLogitsWarper, + FlaxMinLengthLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, diff --git a/src/transformers/generation_flax_logits_process.py b/src/transformers/generation_flax_logits_process.py index da4e77715cf..c6179e63fc7 100644 --- a/src/transformers/generation_flax_logits_process.py +++ b/src/transformers/generation_flax_logits_process.py @@ -81,16 +81,18 @@ class FlaxLogitsProcessorList(list): """ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, **kwargs) -> jax_xla.DeviceArray: + def __call__( + self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int, **kwargs + ) -> jax_xla.DeviceArray: for processor in self: function_args = inspect.signature(processor.__call__).parameters - if len(function_args) > 2: + if len(function_args) > 3: assert all( arg in kwargs for arg in list(function_args.keys())[2:] ), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor." - scores = processor(input_ids, scores, **kwargs) + scores = processor(input_ids, scores, cur_len, **kwargs) else: - scores = processor(input_ids, scores) + scores = processor(input_ids, scores, cur_len) return scores @@ -109,7 +111,9 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper): self.temperature = temperature - def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: + def __call__( + self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int + ) -> jax_xla.DeviceArray: scores = scores / self.temperature return scores @@ -137,7 +141,9 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper): self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep - def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: + def __call__( + self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int + ) -> jax_xla.DeviceArray: topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1]) mask_scores = jnp.full_like(scores, self.filter_value) @@ -177,7 +183,9 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper): self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep - def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: + def __call__( + self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int + ) -> jax_xla.DeviceArray: batch_size, vocab_size = scores.shape next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value) @@ -190,3 +198,94 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper): next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat) next_scores = next_scores_flat.reshape(batch_size, vocab_size) return next_scores + + +class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor): + r""" + :class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the first generated token. + + Args: + bos_token_id (:obj:`int`): + The id of the token to force as the first generated token. + """ + + def __init__(self, bos_token_id: int): + self.bos_token_id = bos_token_id + + def __call__( + self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int + ) -> jax_xla.DeviceArray: + new_scores = jnp.full(scores.shape, -float("inf")) + + apply_penalty = 1 - jnp.bool_(cur_len - 1) + + scores = jnp.where( + apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.bos_token_id], 0), scores + ) + + return scores + + +class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor): + r""" + :class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the last generated token when + :obj:`max_length` is reached. + + Args: + max_length (:obj:`int`): + The maximum length of the sequence to be generated. + eos_token_id (:obj:`int`): + The id of the token to force as the last generated token when :obj:`max_length` is reached. + """ + + def __init__(self, max_length: int, eos_token_id: int): + self.max_length = max_length + self.eos_token_id = eos_token_id + + def __call__( + self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int + ) -> jax_xla.DeviceArray: + new_scores = jnp.full(scores.shape, -float("inf")) + + apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1) + + scores = jnp.where( + apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.eos_token_id], 0), scores + ) + + return scores + + +class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor): + r""" + :class:`transformers.FlaxLogitsProcessor` enforcing a min-length by setting EOS probability to 0. + + Args: + min_length (:obj:`int`): + The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. + eos_token_id (:obj:`int`): + The id of the `end-of-sequence` token. + """ + + def __init__(self, min_length: int, eos_token_id: int): + if not isinstance(min_length, int) or min_length < 0: + raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") + + if not isinstance(eos_token_id, int) or eos_token_id < 0: + raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") + + self.min_length = min_length + self.eos_token_id = eos_token_id + + def __call__( + self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int + ) -> jax_xla.DeviceArray: + + # create boolean flag to decide if min length penalty should be applied + apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1) + + scores = jnp.where( + apply_penalty, jax.ops.index_update(scores, jax.ops.index[:, self.eos_token_id], -float("inf")), scores + ) + + return scores diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index a22bf7c3f62..21889d59c5f 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -17,6 +17,8 @@ from typing import Dict, Optional +import numpy as np + import flax import jax import jax.numpy as jnp @@ -25,7 +27,10 @@ from jax import lax from .file_utils import ModelOutput from .generation_flax_logits_process import ( + FlaxForcedBOSTokenLogitsProcessor, + FlaxForcedEOSTokenLogitsProcessor, FlaxLogitsProcessorList, + FlaxMinLengthLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, @@ -43,9 +48,8 @@ class FlaxGreedySearchOutput(ModelOutput): Args: - sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): - The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is - padded to :obj:`max_length`. + sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`): + The generated sequences. """ sequences: jax_xla.DeviceArray = None @@ -58,19 +62,35 @@ class FlaxSampleOutput(ModelOutput): Args: - sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_length)`): - The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is - padded to :obj:`max_length`. + sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`): + The generated sequences. """ sequences: jax_xla.DeviceArray = None +@flax.struct.dataclass +class FlaxBeamSearchOutput(ModelOutput): + """ + Flax Base class for outputs of decoder-only generation models using greedy search. + + + Args: + sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`): + The generated sequences. + scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`): + The scores (log probabilites) of the generated sequences. + """ + + sequences: jax_xla.DeviceArray = None + scores: jax_xla.DeviceArray = None + + @flax.struct.dataclass class GreedyState: cur_len: jax_xla.DeviceArray sequences: jax_xla.DeviceArray - current_token: jax_xla.DeviceArray + running_token: jax_xla.DeviceArray is_sent_finished: jax_xla.DeviceArray model_kwargs: Dict[str, jax_xla.DeviceArray] @@ -79,12 +99,23 @@ class GreedyState: class SampleState: cur_len: jax_xla.DeviceArray sequences: jax_xla.DeviceArray - current_token: jax_xla.DeviceArray + running_token: jax_xla.DeviceArray is_sent_finished: jax_xla.DeviceArray prng_key: jax_xla.DeviceArray model_kwargs: Dict[str, jax_xla.DeviceArray] +@flax.struct.dataclass +class BeamSearchState: + cur_len: jax_xla.DeviceArray + running_sequences: jax_xla.DeviceArray + running_scores: jax_xla.DeviceArray + sequences: jax_xla.DeviceArray + scores: jax_xla.DeviceArray + is_sent_finished: jax_xla.DeviceArray + model_kwargs: Dict[str, jax_xla.DeviceArray] + + class FlaxGenerationMixin: """ A class containing all of the functions supporting generation, to be used as a mixin in @@ -110,6 +141,10 @@ class FlaxGenerationMixin: model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs) return model_kwargs + @staticmethod + def _expand_to_num_beams(tensor, num_beams): + return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:]) + def generate( self, input_ids: jax_xla.DeviceArray, @@ -123,6 +158,13 @@ class FlaxGenerationMixin: top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None, + num_beams: Optional[int] = None, + no_repeat_ngram_size: Optional[int] = None, + min_length: Optional[int] = None, + forced_bos_token_id: Optional[int] = None, + forced_eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + early_stopping: Optional[bool] = None, trace: bool = True, params: Optional[Dict[str, jax_xla.DeviceArray]] = None, **model_kwargs, @@ -159,6 +201,8 @@ class FlaxGenerationMixin: The id of the `beginning-of-sequence` token. eos_token_id (:obj:`int`, `optional`): The id of the `end-of-sequence` token. + num_beams (:obj:`int`, `optional`, defaults to 1): + Number of beams for beam search. 1 means no beam search. decoder_start_token_id (:obj:`int`, `optional`): If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token. trace (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -204,9 +248,27 @@ class FlaxGenerationMixin: input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id do_sample = do_sample if do_sample is not None else self.config.do_sample + num_beams = num_beams if num_beams is not None else self.config.num_beams - if do_sample: + if not do_sample and num_beams == 1: + logits_processor = self._get_logits_processor( + no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id + ) + return self._greedy_search( + input_ids, + max_length, + pad_token_id, + eos_token_id, + logits_processor=logits_processor, + trace=trace, + params=params, + model_kwargs=model_kwargs, + ) + elif do_sample and num_beams == 1: logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) + logits_processor = self._get_logits_processor( + no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id + ) return self._sample( input_ids, max_length, @@ -214,20 +276,43 @@ class FlaxGenerationMixin: eos_token_id, prng_key, logits_warper=logits_warper, + logits_processor=logits_processor, + trace=trace, + params=params, + model_kwargs=model_kwargs, + ) + elif not do_sample and num_beams > 1: + # broadcast input_ids & encoder_outputs + input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams) + + if "encoder_outputs" in model_kwargs: + model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams( + model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams + ) + + if "attention_mask" in model_kwargs: + model_kwargs["attention_mask"] = self._expand_to_num_beams( + model_kwargs["attention_mask"], num_beams=num_beams + ) + + logits_processor = self._get_logits_processor( + no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id + ) + + return self._beam_search( + input_ids, + max_length, + pad_token_id, + eos_token_id, + length_penalty=length_penalty, + early_stopping=early_stopping, + logits_processor=logits_processor, trace=trace, params=params, model_kwargs=model_kwargs, ) else: - return self._greedy_search( - input_ids, - max_length, - pad_token_id, - eos_token_id, - trace=trace, - params=params, - model_kwargs=model_kwargs, - ) + raise NotImplementedError("`Beam sampling is currently not implemented.") def _get_logits_warper( self, top_k: int = None, top_p: float = None, temperature: float = None @@ -255,12 +340,51 @@ class FlaxGenerationMixin: return warpers + def _get_logits_processor( + self, + no_repeat_ngram_size: int, + min_length: int, + max_length: int, + eos_token_id: int, + forced_bos_token_id: int, + forced_eos_token_id: int, + ) -> FlaxLogitsProcessorList: + """ + This class returns a :obj:`~transformers.FlaxLogitsProcessorList` list object that contains all relevant + :obj:`~transformers.FlaxLogitsProcessor` instances used to modify the scores of the language model head. + """ + processors = FlaxLogitsProcessorList() + + # init warp parameters + no_repeat_ngram_size = ( + no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size + ) + min_length = min_length if min_length is not None else self.config.min_length + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + forced_bos_token_id = ( + forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id + ) + forced_eos_token_id = ( + forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id + ) + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if min_length is not None and eos_token_id is not None and min_length > -1: + processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id)) + if forced_bos_token_id is not None: + processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id)) + if forced_eos_token_id is not None: + processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) + return processors + def _greedy_search( self, input_ids: None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, + logits_processor: Optional[FlaxLogitsProcessorList] = None, trace: bool = True, params: Optional[Dict[str, jax_xla.DeviceArray]] = None, model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None, @@ -293,7 +417,7 @@ class FlaxGenerationMixin: state = GreedyState( cur_len=cur_len, sequences=sequences, - current_token=input_ids, + running_token=input_ids, is_sent_finished=is_sent_finished, model_kwargs=model_kwargs, ) @@ -307,8 +431,13 @@ class FlaxGenerationMixin: def greedy_search_body_fn(state): """state update fn.""" - model_outputs = model(state.current_token, params=params, **state.model_kwargs) - next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1) + model_outputs = model(state.running_token, params=params, **state.model_kwargs) + logits = model_outputs.logits[:, -1] + + # apply min_length, ... + logits = logits_processor(state.sequences, logits, state.cur_len) + + next_token = jnp.argmax(logits, axis=-1) next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id) next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished @@ -319,7 +448,7 @@ class FlaxGenerationMixin: return GreedyState( cur_len=state.cur_len + 1, sequences=next_sequences, - current_token=next_token, + running_token=next_token, is_sent_finished=next_is_sent_finished, model_kwargs=next_model_kwargs, ) @@ -342,6 +471,7 @@ class FlaxGenerationMixin: pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, prng_key: Optional[jax_xla.DeviceArray] = None, + logits_processor: Optional[FlaxLogitsProcessorList] = None, logits_warper: Optional[FlaxLogitsProcessorList] = None, trace: bool = True, params: Optional[Dict[str, jax_xla.DeviceArray]] = None, @@ -377,7 +507,7 @@ class FlaxGenerationMixin: state = SampleState( cur_len=cur_len, sequences=sequences, - current_token=input_ids, + running_token=input_ids, is_sent_finished=is_sent_finished, prng_key=prng_key, model_kwargs=model_kwargs, @@ -393,12 +523,14 @@ class FlaxGenerationMixin: def sample_search_body_fn(state): """state update fn.""" prng_key, prng_key_next = jax.random.split(state.prng_key) - model_outputs = model(state.current_token, params=params, **state.model_kwargs) + model_outputs = model(state.running_token, params=params, **state.model_kwargs) logits = model_outputs.logits[:, -1] + # apply min_length, ... + logits = logits_processor(state.sequences, logits, state.cur_len) # apply top_k, top_k, temperature - logits = logits_warper(state.sequences, logits) + logits = logits_warper(logits, logits, state.cur_len) next_token = jax.random.categorical(prng_key, model_outputs.logits[:, -1], axis=-1) @@ -412,7 +544,7 @@ class FlaxGenerationMixin: return SampleState( cur_len=state.cur_len + 1, sequences=next_sequences, - current_token=next_token, + running_token=next_token, is_sent_finished=next_is_sent_finished, model_kwargs=next_model_kwargs, prng_key=prng_key_next, @@ -428,3 +560,251 @@ class FlaxGenerationMixin: state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state) return FlaxSampleOutput(sequences=state.sequences) + + def _beam_search( + self, + input_ids: None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + early_stopping: Optional[bool] = None, + logits_processor: Optional[FlaxLogitsProcessorList] = None, + trace: bool = True, + params: Optional[Dict[str, jax_xla.DeviceArray]] = None, + model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None, + ): + """ + This beam search function is heavily inspired by Flax's official example: + https://github.com/google/flax/blob/master/examples/wmt/train.py#L254 + """ + + def flatten_beam_dim(tensor): + """Flattens the first two dimensions of a non-scalar array.""" + # ignore scalars (e.g. cache index) + if tensor.ndim == 0: + return tensor + return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:]) + + def unflatten_beam_dim(tensor, batch_size, num_beams): + """Unflattens the first, flat batch*beam dimension of a non-scalar array.""" + # ignore scalars (e.g. cache index) + if tensor.ndim == 0: + return tensor + return tensor.reshape((batch_size, num_beams) + tensor.shape[1:]) + + def gather_beams(nested, beam_indices, batch_size, new_num_beams): + """ + Gathers the beam slices indexed by beam_indices into new beam array. + """ + batch_indices = jnp.reshape( + jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams) + ) + + def gather_fn(tensor): + # ignore scalars (e.g. cache index) + if tensor.ndim == 0: + return tensor + else: + return tensor[batch_indices, beam_indices] + + return jax.tree_map(gather_fn, nested) + + # init values + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + + batch_size, num_beams, cur_len = input_ids.shape + + eos_token_id = jnp.array(eos_token_id) + pad_token_id = jnp.array(pad_token_id) + cur_len = jnp.array(cur_len) + + # per batch,beam-item holding current token in loop. + sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32) + running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32) + running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0)) + + # per batch,beam-item state bit indicating if sentence has finished. + is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_) + + # per batch,beam-item score, logprobs + running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1]) + scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7) + + # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop + # and pass it the `encoder_outputs`, which are part of the `model_kwargs`. + model = self.decode if self.config.is_encoder_decoder else self + + # flatten beam dim + if "encoder_outputs" in model_kwargs: + model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim( + model_kwargs["encoder_outputs"]["last_hidden_state"] + ) + if "attention_mask" in model_kwargs: + model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"]) + + # initialize model specific kwargs + model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs) + + # initialize state + state = BeamSearchState( + cur_len=cur_len, + running_sequences=running_sequences, + running_scores=running_scores, + sequences=sequences, + scores=scores, + is_sent_finished=is_sent_finished, + model_kwargs=model_kwargs, + ) + + def beam_search_cond_fn(state): + """beam search state termination condition fn.""" + + # 1. is less than max length? + not_max_length_yet = state.cur_len < max_length + + # 2. can the new beams still improve? + best_running_score = state.running_scores[:, -1:] / (max_length ** length_penalty) + worst_finished_score = jnp.where( + state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7) + ) + improvement_still_possible = jnp.all(worst_finished_score < best_running_score) + + # 3. is there still a beam that has not finished? + still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping) + + return not_max_length_yet & still_open_beam & improvement_still_possible + + def beam_search_body_fn(state): + """beam search state update fn.""" + # 1. Forward current tokens + # Collect the current position slice along length to feed the fast + # autoregressive decoder model. Flatten the beam dimension into batch + # dimension for feeding into the model. + # unflatten beam dimension + # Unflatten beam dimension in attention cache arrays + input_token = flatten_beam_dim( + lax.dynamic_slice(state.running_sequences, (0, 0, state.cur_len - 1), (batch_size, num_beams, 1)) + ) + model_outputs = model(input_token, params=params, **state.model_kwargs) + logits = unflatten_beam_dim(model_outputs.logits[:, 0], batch_size, num_beams) + cache = jax.tree_map( + lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values + ) + + # 2. Compute log probs + # get log probabilities from logits, + # process logits with processors (*e.g.* min_length, ...), and + # add new logprobs to existing running logprobs scores. + log_probs = jax.nn.log_softmax(logits) + log_probs = logits_processor( + flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len + ) + log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams) + log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2) + vocab_size = log_probs.shape[2] + log_probs = log_probs.reshape((batch_size, num_beams * vocab_size)) + + # 3. Retrieve top-K + # Each item in batch has num_beams * vocab_size candidate sequences. + # For each item, get the top 2*k candidates with the highest log- + # probabilities. We gather the top 2*K beams here so that even if the best + # K sequences reach EOS simultaneously, we have another K sequences + # remaining to continue the live beam search. + # Gather the top 2*K scores from _all_ beams. + # Gather 2*k top beams. + # Recover the beam index by floor division. + # Recover token id by modulo division and expand Id array for broadcasting. + # Update sequences for the 2*K top-k new sequences. + beams_to_keep = 2 * num_beams + topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep) + topk_beam_indices = topk_indices // vocab_size + topk_running_sequences = gather_beams( + state.running_sequences, topk_beam_indices, batch_size, beams_to_keep + ) + topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) + topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len)) + + # 4. Check which sequences have ended + # Update current sequences: + # Did any of these sequences reach an end marker? + # To prevent these just finished sequences from being added to the current sequences + # set of active beam search sequences, set their log probs to a very large + # negative value. + did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id + topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7) + + # 5. Get running sequences scores for next + # Determine the top k beam indices (from top 2*k beams) from log probs + # and gather top k beams (from top 2*k beams). + next_topk_indices = jnp.flip(lax.top_k(topk_log_probs, k=num_beams)[1], axis=1) + next_running_sequences, next_running_scores = gather_beams( + [topk_sequences, topk_log_probs], next_topk_indices, batch_size, num_beams + ) + + # 6. Process topk logits + # Further process log probs: + # - add length penalty + # - make sure no scores can be added anymore if beam is full + # - make sure still running sequences cannot be chosen as finalized beam + topk_log_probs = topk_log_probs / (state.cur_len ** length_penalty) + beams_in_batch_are_full = ( + jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape) + & early_stopping + ) + add_penalty = ~did_topk_just_finished | beams_in_batch_are_full + topk_log_probs += add_penalty * np.array(-1.0e7) + + # 7. Get scores, sequences, is sentence finished for next. + # Combine sequences, scores, and flags along the beam dimension and compare + # new finished sequence scores to existing finished scores and select the + # best from the new set of beams + merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1) + merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1) + merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1) + topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1) + next_sequences, next_scores, next_is_sent_finished = gather_beams( + [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams + ) + + # 8. Update model kwargs. + # Determine the top k beam indices from the original set of all beams. + # With these, gather the top k beam-associated caches. + next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams) + next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams) + model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache) + next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs) + + return BeamSearchState( + cur_len=state.cur_len + 1, + running_scores=next_running_scores, + running_sequences=next_running_sequences, + scores=next_scores, + sequences=next_sequences, + is_sent_finished=next_is_sent_finished, + model_kwargs=next_model_kwargs, + ) + + # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU + state = beam_search_body_fn(state) + + if not trace: + state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state) + else: + state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state) + + # Account for the edge-case where there are no finished sequences for a + # particular batch item. If so, return running sequences for that batch item. + none_finished = jnp.any(state.is_sent_finished, axis=1) + sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences) + scores = jnp.where(none_finished[:, None], state.scores, state.running_scores) + + # take best beam for each batch + sequences = sequences[:, -1] + scores = scores[:, -1] + + return FlaxBeamSearchOutput(sequences=sequences, scores=scores) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index f4cbcb24968..7bae4a9a763 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -2,6 +2,24 @@ from ..file_utils import requires_backends +class FlaxForcedBOSTokenLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxForcedEOSTokenLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxLogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @@ -25,6 +43,15 @@ class FlaxLogitsWarper: requires_backends(self, ["flax"]) +class FlaxMinLengthLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxTemperatureLogitsWarper: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) diff --git a/tests/test_generation_flax_logits_process.py b/tests/test_generation_flax_logits_process.py index 4dacb5dc0ad..dd74783a80e 100644 --- a/tests/test_generation_flax_logits_process.py +++ b/tests/test_generation_flax_logits_process.py @@ -28,7 +28,10 @@ if is_flax_available(): import jax import jax.numpy as jnp from transformers.generation_flax_logits_process import ( + FlaxForcedBOSTokenLogitsProcessor, + FlaxForcedEOSTokenLogitsProcessor, FlaxLogitsProcessorList, + FlaxMinLengthLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, @@ -57,8 +60,8 @@ class LogitsProcessorTest(unittest.TestCase): temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5) temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3) - warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy()), axis=-1) - warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy()), axis=-1) + warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy(), cur_len=None), axis=-1) + warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy(), cur_len=None), axis=-1) # uniform distribution stays uniform self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)) @@ -83,7 +86,7 @@ class LogitsProcessorTest(unittest.TestCase): top_k_warp = FlaxTopKLogitsWarper(3) - scores = top_k_warp(input_ids, ramp_logits) + scores = top_k_warp(input_ids, ramp_logits, cur_len=None) # check that correct tokens are filtered self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False]) @@ -94,7 +97,7 @@ class LogitsProcessorTest(unittest.TestCase): top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3) ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy() - scores = top_k_warp_safety_check(input_ids, ramp_logits) + scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len=None) # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2]) @@ -108,7 +111,7 @@ class LogitsProcessorTest(unittest.TestCase): dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]])) top_p_warp = FlaxTopPLogitsWarper(0.7) - filtered_dist = np.exp(top_p_warp(input_ids, dist)) + filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None)) # dist should be filtered to keep min num values so that sum is >= 0.7 # exp (-inf) => 0 @@ -125,15 +128,81 @@ class LogitsProcessorTest(unittest.TestCase): # make sure at least 2 tokens are kept top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0) - filtered_dist = top_p_warp(input_ids, ramp_logits) + filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len=None) # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2]) + def test_min_length_dist_processor(self): + vocab_size = 20 + batch_size = 4 + eos_token_id = 0 + + min_dist_processor = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + + # check that min length is applied at length 5 + input_ids = ids_tensor((batch_size, 20), vocab_size=20) + cur_len = 5 + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len) + self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")]) + + # check that min length is not applied anymore at length 15 + scores = self._get_uniform_logits(batch_size, vocab_size) + cur_len = 15 + scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len) + self.assertFalse(jnp.isinf(scores_before_min_length).any()) + + def test_forced_bos_token_logits_processor(self): + vocab_size = 20 + batch_size = 4 + bos_token_id = 0 + + logits_processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id) + + # check that all scores are -inf except the bos_token_id score + input_ids = ids_tensor((batch_size, 1), vocab_size=20) + cur_len = 1 + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores, cur_len=cur_len) + self.assertTrue(jnp.isneginf(scores[:, bos_token_id + 1 :]).all()) + self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0]) # score for bos_token_id shold be zero + + # check that bos_token_id is not forced if current length is greater than 1 + cur_len = 3 + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores, cur_len=cur_len) + self.assertFalse(jnp.isinf(scores).any()) + + def test_forced_eos_token_logits_processor(self): + vocab_size = 20 + batch_size = 4 + eos_token_id = 0 + max_length = 5 + + logits_processor = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id) + + # check that all scores are -inf except the eos_token_id when max_length is reached + input_ids = ids_tensor((batch_size, 4), vocab_size=20) + cur_len = 4 + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores, cur_len=cur_len) + self.assertTrue(jnp.isneginf(scores[:, eos_token_id + 1 :]).all()) + self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero + + # check that eos_token_id is not forced if max_length is not reached + cur_len = 3 + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores, cur_len=cur_len) + self.assertFalse(jnp.isinf(scores).any()) + def test_processor_list(self): batch_size = 4 sequence_length = 10 vocab_size = 15 + eos_token_id = 2 + bos_token_id = 1 + max_length = 15 # dummy input_ids and scores input_ids = ids_tensor((batch_size, sequence_length), vocab_size) @@ -147,14 +216,83 @@ class LogitsProcessorTest(unittest.TestCase): top_k_warp = FlaxTopKLogitsWarper(3) top_p_warp = FlaxTopPLogitsWarper(0.8) + # instantiate all logits processors + min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id) + eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id) + + cur_len = 10 + # no processor list - scores = temp_dist_warp(input_ids, scores) - scores = top_k_warp(input_ids, scores) - scores = top_p_warp(input_ids, scores) + scores = temp_dist_warp(input_ids, scores, cur_len=cur_len) + scores = top_k_warp(input_ids, scores, cur_len=cur_len) + scores = top_p_warp(input_ids, scores, cur_len=cur_len) + scores = min_dist_proc(input_ids, scores, cur_len=cur_len) + scores = bos_dist_proc(input_ids, scores, cur_len=cur_len) + scores = eos_dist_proc(input_ids, scores, cur_len=cur_len) # with processor list - processor = FlaxLogitsProcessorList([temp_dist_warp, top_k_warp, top_p_warp]) - scores_comp = processor(input_ids, scores_comp) + processor = FlaxLogitsProcessorList( + [temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc] + ) + scores_comp = processor(input_ids, scores_comp, cur_len=cur_len) + + # scores should be equal + self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3)) + + # input_ids should never be changed + self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist()) + + def test_processor_list_jitted(self): + batch_size = 4 + sequence_length = 10 + vocab_size = 15 + eos_token_id = 2 + bos_token_id = 1 + max_length = 15 + + # dummy input_ids and scores + input_ids = ids_tensor((batch_size, sequence_length), vocab_size) + input_ids_comp = input_ids.copy() + + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_comp = scores.copy() + + # instantiate all dist processors + temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5) + top_k_warp = FlaxTopKLogitsWarper(3) + top_p_warp = FlaxTopPLogitsWarper(0.8) + + # instantiate all logits processors + min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id) + eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id) + + cur_len = 10 + + # no processor list + def run_no_processor_list(input_ids, scores, cur_len): + scores = temp_dist_warp(input_ids, scores, cur_len=cur_len) + scores = top_k_warp(input_ids, scores, cur_len=cur_len) + scores = top_p_warp(input_ids, scores, cur_len=cur_len) + scores = min_dist_proc(input_ids, scores, cur_len=cur_len) + scores = bos_dist_proc(input_ids, scores, cur_len=cur_len) + scores = eos_dist_proc(input_ids, scores, cur_len=cur_len) + return scores + + # with processor list + def run_processor_list(input_ids, scores, cur_len): + processor = FlaxLogitsProcessorList( + [temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc] + ) + scores = processor(input_ids, scores, cur_len=cur_len) + return scores + + jitted_run_no_processor_list = jax.jit(run_no_processor_list) + jitted_run_processor_list = jax.jit(run_processor_list) + + scores = jitted_run_no_processor_list(input_ids, scores, cur_len) + scores_comp = jitted_run_processor_list(input_ids, scores_comp, cur_len) # scores should be equal self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3)) diff --git a/tests/test_generation_flax_utils.py b/tests/test_generation_flax_utils.py index 9b3e529c185..b5e0f086095 100644 --- a/tests/test_generation_flax_utils.py +++ b/tests/test_generation_flax_utils.py @@ -110,6 +110,23 @@ class FlaxGenerationTesterMixin: self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) + def test_beam_search_generate(self): + config, input_ids, _, max_length = self._get_input_ids_and_config() + config.do_sample = False + config.max_length = max_length + config.num_beams = 2 + + for model_class in self.all_generative_model_classes: + model = model_class(config) + + generation_outputs = model.generate(input_ids).sequences + self.assertEqual(generation_outputs.shape[-1], max_length) + + jit_generate = jit(model.generate) + jit_generation_outputs = jit_generate(input_ids).sequences + + self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) + def test_sample_generate_logits_warper(self): config, input_ids, _, max_length = self._get_input_ids_and_config() config.do_sample = True @@ -117,6 +134,46 @@ class FlaxGenerationTesterMixin: config.temperature = 0.8 config.top_k = 10 config.top_p = 0.3 + config.min_length = 1 + config.forced_bos_token_id = 8 + config.forced_eos_token_id = 9 + + for model_class in self.all_generative_model_classes: + model = model_class(config) + + generation_outputs = model.generate(input_ids).sequences + self.assertEqual(generation_outputs.shape[-1], max_length) + + jit_generate = jit(model.generate) + jit_generation_outputs = jit_generate(input_ids).sequences + + self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) + + def test_greedy_generate_logits_warper(self): + config, input_ids, _, max_length = self._get_input_ids_and_config() + config.max_length = max_length + config.min_length = 1 + config.forced_bos_token_id = 8 + config.forced_eos_token_id = 9 + + for model_class in self.all_generative_model_classes: + model = model_class(config) + + generation_outputs = model.generate(input_ids).sequences + self.assertEqual(generation_outputs.shape[-1], max_length) + + jit_generate = jit(model.generate) + jit_generation_outputs = jit_generate(input_ids).sequences + + self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) + + def test_beam_search_generate_logits_warper(self): + config, input_ids, _, max_length = self._get_input_ids_and_config() + config.max_length = max_length + config.num_beams = 2 + config.min_length = 1 + config.forced_bos_token_id = 8 + config.forced_eos_token_id = 9 for model_class in self.all_generative_model_classes: model = model_class(config) @@ -168,3 +225,23 @@ class FlaxGenerationTesterMixin: jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) + + def test_beam_search_generate_attn_mask(self): + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # pad attention mask on the left + attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0) + + config.num_beams = 2 + config.max_length = max_length + + for model_class in self.all_generative_model_classes: + model = model_class(config) + + generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences + self.assertEqual(generation_outputs.shape[-1], max_length) + + jit_generate = jit(model.generate) + jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences + + self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) diff --git a/tests/test_modeling_flax_bart.py b/tests/test_modeling_flax_bart.py index f446c4556f2..29981388f58 100644 --- a/tests/test_modeling_flax_bart.py +++ b/tests/test_modeling_flax_bart.py @@ -34,6 +34,7 @@ if is_flax_available(): import jax import jax.numpy as jnp + from transformers import BartTokenizer from transformers.models.bart.modeling_flax_bart import ( FlaxBartForConditionalGeneration, FlaxBartForQuestionAnswering, @@ -415,3 +416,62 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationT input_ids = np.ones((1, 1)) * model.config.eos_token_id outputs = model(input_ids) self.assertIsNotNone(outputs) + + @slow + def test_summarization_fast(self): + model = FlaxBartForConditionalGeneration.from_pretrained("sshleifer/distilbart-cnn-6-6") + tokenizer = BartTokenizer.from_pretrained("sshleifer/distilbart-cnn-6-6") + + input_str = "This sentence is made of three parts. Each part is important on its own. One part is about animals, the other part about planes, and the last part about housing." + + input_ids = tokenizer(input_str, return_tensors="np").input_ids + sequences = model.generate(input_ids, num_beams=2, max_length=20).sequences + + output_str = tokenizer.batch_decode(sequences)[0] + + assert ( + output_str == "This sentence is made of three parts. One part is about animals, the other part" + ) + + @slow + def test_cnn_summarization_same_as_fairseq(self): + model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") + tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + + FRANCE_ARTICLE = ' Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noq + + SHORTER_ARTICLE = ' (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.' + + # The below article tests that we don't add any hypotheses outside of the top n_beams + IRAN_ARTICLE = " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions." + + ARTICLE_SUBWAY = ' New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.' + + dct = tokenizer.batch_encode_plus( + [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY], + max_length=1024, + padding="max_length", + truncation_strategy="only_first", + truncation=True, + return_tensors="jax", + ) + + self.assertEqual(1024, dct["input_ids"].shape[1]) + hypotheses_batch = model.generate( + input_ids=dct["input_ids"], + attention_mask=dct["attention_mask"], + num_beams=2, + ).sequences + assert (hypotheses_batch[:, 1] == 0).all().item() + + EXPECTED = [ + "A French prosecutor says he is not aware of any video footage from on board the plane. Two German magazines claim to have found a cell phone video showing the crash. The publications say they watched the video, which was found by a source close to the investigation. All 150 on board the Germanwings flight were killed.", + "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a move toward greater justice.", + "U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The debate that has already begun will likely result in more heat than light. Bergen: The most misleading assertion is that the negotiations' objective at the outset was the total elimination of any nuclear program.", + "Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the Bronx on Friday. If convicted, Barrientos faces up to four years in prison.", + ] + + generated_summaries = tokenizer.batch_decode( + hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True + ) + assert generated_summaries == EXPECTED