TF: sample generation compatible with XLA and dynamic batch sizes (#19773)

This commit is contained in:
Joao Gante 2022-10-20 19:01:22 +01:00 committed by GitHub
parent c186e816bd
commit 5ed9bd1896
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -360,6 +360,7 @@ class TFGenerationMixin:
@property
def seed_generator(self):
warnings.warn("`seed_generator` is deprecated and will be removed in a future version.", UserWarning)
if self._seed_generator is None:
self._seed_generator = tf.random.Generator.from_non_deterministic_state()
return self._seed_generator
@ -1920,7 +1921,7 @@ class TFGenerationMixin:
**model_kwargs,
) -> Tuple[tf.Tensor, Dict[str, Any]]:
expanded_return_idx = tf.reshape(
tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,)
tf.tile(tf.reshape(tf.range(tf.shape(input_ids)[0]), (-1, 1)), (1, expand_size)), (-1,)
)
input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
@ -2624,7 +2625,7 @@ class TFGenerationMixin:
if seed is not None:
sample_seed = seed
else:
sample_seed = tf.cast(self.seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32)
sample_seed = tf.experimental.numpy.random.randint(tf.int32.min, tf.int32.max, (2,), dtype=tf.int32)
next_tokens = tf.squeeze(
tf.random.stateless_categorical(
logits=next_tokens_scores, num_samples=1, seed=sample_seed, dtype=tf.int32