Protect TFGenerationMixin.seed_generator so it's not created at import (#18044)

This commit is contained in:
Matt 2022-07-06 16:36:28 +01:00 committed by GitHub
parent 360719a6a4
commit be79cd7d8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -346,7 +346,14 @@ class TFGenerationMixin:
A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`].
"""
seed_generator = tf.random.Generator.from_non_deterministic_state()
_seed_generator = None
@property
def seed_generator(self):
if self._seed_generator is None:
self._seed_generator = tf.random.Generator.from_non_deterministic_state()
return self._seed_generator
supports_xla_generation = True
def prepare_inputs_for_generation(self, inputs, **kwargs):