[Flax generate] Add params to generate (#12171)

* fix_torch_device_generate_test

* remove @

* add params as input

* finish
Patrick von Platen 2021-06-15 11:50:12 +01:00 committed by GitHub
parent a55dc157e3
commit 9bc9e59869

@@ -124,6 +124,7 @@ class FlaxGenerationMixin:
         top_p: Optional[float] = None,
         temperature: Optional[float] = None,
         trace: bool = True,
+        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
         **model_kwargs,
     ):
         r"""
@@ -163,6 +164,8 @@
             trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
                 a considerably slower runtime.
+            params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
+                Optionally the model parameters can be passed. Can be useful for parallelized generation.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
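For reference, a minimal sketch of the new argument from the caller's side. The checkpoint and tokenizer names here are illustrative, not part of this commit; when `params` is omitted, the model falls back to the parameters stored on the instance (`model.params`).

```python
# Minimal usage sketch (illustrative checkpoint/tokenizer names).
from transformers import FlaxGPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("Hello, my dog is", return_tensors="np").input_ids

# Pass the parameter tree explicitly instead of relying on `model.params`.
generated = model.generate(input_ids, max_length=16, params=model.params)

# `generated.sequences` holds the generated token ids.
print(tokenizer.batch_decode(generated.sequences, skip_special_tokens=True))
```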
@@ -211,12 +214,19 @@
                 eos_token_id,
                 prng_key,
                 logits_warper=logits_warper,
-                model_kwargs=model_kwargs,
                 trace=trace,
+                params=params,
+                model_kwargs=model_kwargs,
             )
         else:
             return self._greedy_search(
-                input_ids, max_length, pad_token_id, eos_token_id, trace=trace, model_kwargs=model_kwargs
+                input_ids,
+                max_length,
+                pad_token_id,
+                eos_token_id,
+                trace=trace,
+                params=params,
+                model_kwargs=model_kwargs,
             )
 
     def _get_logits_warper(
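The "parallelized generation" use case mentioned in the docstring is the main motivation for threading `params` through both branches above: once the parameter tree is an explicit input, `generate` can be wrapped in `jax.pmap` over replicated parameters. A hedged sketch continuing the example above (`flax.jax_utils.replicate` is a real helper; the input sharding is an assumption):

```python
import jax
from flax.jax_utils import replicate

# Replicate the parameter tree once per local device.
p_params = replicate(model.params)

def generate_step(batch_input_ids, params):
    # `params` is now an explicit pmap input instead of implicit model state.
    return model.generate(batch_input_ids, max_length=16, params=params).sequences

p_generate = jax.pmap(generate_step)

# `sharded_ids` must carry a leading axis of size jax.local_device_count():
# sequences = p_generate(sharded_ids, p_params)
```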
@@ -252,6 +262,7 @@
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         trace: bool = True,
+        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
         model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
     ):
         # init values
@@ -296,7 +307,7 @@
         def greedy_search_body_fn(state):
             """state update fn."""
-            model_outputs = model(state.current_token, **state.model_kwargs)
+            model_outputs = model(state.current_token, params=params, **state.model_kwargs)
             next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1)
             next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
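Note that `params` reaches the compiled loop purely by closure: `greedy_search_body_fn` captures it, so when `trace=True` lowers the loop via `lax.while_loop`, the parameters are traced as ordinary array inputs rather than being baked into the model object. A tiny self-contained illustration of the same pattern (toy names, not from this file):

```python
import jax.numpy as jnp
from jax import lax

weights = jnp.array([0.5, 1.5])  # stands in for `params`

def cond_fn(state):
    i, _ = state
    return i < 4

def body_fn(state):
    # `weights` is captured by closure, just like `params` above.
    i, acc = state
    return i + 1, acc + weights[i % 2]

final_i, total = lax.while_loop(cond_fn, body_fn, (jnp.int32(0), jnp.float32(0.0)))
```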
@@ -331,9 +342,10 @@
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         prng_key: Optional[jax_xla.DeviceArray] = None,
-        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
         logits_warper: Optional[FlaxLogitsProcessorList] = None,
         trace: bool = True,
+        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
     ):
         # init values
         max_length = max_length if max_length is not None else self.config.max_length
@@ -381,7 +393,7 @@
         def sample_search_body_fn(state):
             """state update fn."""
             prng_key, prng_key_next = jax.random.split(state.prng_key)
-            model_outputs = model(state.current_token, **state.model_kwargs)
+            model_outputs = model(state.current_token, params=params, **state.model_kwargs)
             logits = model_outputs.logits[:, -1]
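For completeness, a sketch of the sampling path with both the PRNG key and the parameters made explicit, continuing the example above (the `top_p`/`temperature` values are arbitrary):

```python
import jax

# Explicit PRNG key: it is split inside sample_search_body_fn on every step,
# so reusing the same key reproduces the same samples.
key = jax.random.PRNGKey(0)

sampled = model.generate(
    input_ids,
    do_sample=True,
    max_length=16,
    top_p=0.9,
    temperature=0.7,
    prng_key=key,
    params=model.params,
)
```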