mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
TF: sample generation compatible with XLA and dynamic batch sizes (#19773)
This commit is contained in:
parent
c186e816bd
commit
5ed9bd1896
@ -360,6 +360,7 @@ class TFGenerationMixin:
|
||||
|
||||
@property
|
||||
def seed_generator(self):
|
||||
warnings.warn("`seed_generator` is deprecated and will be removed in a future version.", UserWarning)
|
||||
if self._seed_generator is None:
|
||||
self._seed_generator = tf.random.Generator.from_non_deterministic_state()
|
||||
return self._seed_generator
|
||||
@ -1920,7 +1921,7 @@ class TFGenerationMixin:
|
||||
**model_kwargs,
|
||||
) -> Tuple[tf.Tensor, Dict[str, Any]]:
|
||||
expanded_return_idx = tf.reshape(
|
||||
tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,)
|
||||
tf.tile(tf.reshape(tf.range(tf.shape(input_ids)[0]), (-1, 1)), (1, expand_size)), (-1,)
|
||||
)
|
||||
input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
|
||||
|
||||
@ -2624,7 +2625,7 @@ class TFGenerationMixin:
|
||||
if seed is not None:
|
||||
sample_seed = seed
|
||||
else:
|
||||
sample_seed = tf.cast(self.seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32)
|
||||
sample_seed = tf.experimental.numpy.random.randint(tf.int32.min, tf.int32.max, (2,), dtype=tf.int32)
|
||||
next_tokens = tf.squeeze(
|
||||
tf.random.stateless_categorical(
|
||||
logits=next_tokens_scores, num_samples=1, seed=sample_seed, dtype=tf.int32
|
||||
|
Loading…
Reference in New Issue
Block a user