mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Flax generate] Add params to generate (#12171)
* fix_torch_device_generate_test
* remove @
* add params as input
* finish
This commit is contained in:
parent
a55dc157e3
commit
9bc9e59869
@ -124,6 +124,7 @@ class FlaxGenerationMixin:
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
trace: bool = True,
|
||||
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||
**model_kwargs,
|
||||
):
|
||||
r"""
|
||||
@ -163,6 +164,8 @@ class FlaxGenerationMixin:
|
||||
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
|
||||
a considerably slower runtime.
|
||||
params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
|
||||
Optionally the model parameters can be passed. Can be useful for parallelized generation.
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
||||
|
||||
@ -211,12 +214,19 @@ class FlaxGenerationMixin:
|
||||
eos_token_id,
|
||||
prng_key,
|
||||
logits_warper=logits_warper,
|
||||
model_kwargs=model_kwargs,
|
||||
trace=trace,
|
||||
params=params,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
else:
|
||||
return self._greedy_search(
|
||||
input_ids, max_length, pad_token_id, eos_token_id, trace=trace, model_kwargs=model_kwargs
|
||||
input_ids,
|
||||
max_length,
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
trace=trace,
|
||||
params=params,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
def _get_logits_warper(
|
||||
@ -252,6 +262,7 @@ class FlaxGenerationMixin:
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
trace: bool = True,
|
||||
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||
):
|
||||
# init values
|
||||
@ -296,7 +307,7 @@ class FlaxGenerationMixin:
|
||||
|
||||
def greedy_search_body_fn(state):
|
||||
"""state update fn."""
|
||||
model_outputs = model(state.current_token, **state.model_kwargs)
|
||||
model_outputs = model(state.current_token, params=params, **state.model_kwargs)
|
||||
next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1)
|
||||
|
||||
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
||||
@ -331,9 +342,10 @@ class FlaxGenerationMixin:
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
prng_key: Optional[jax_xla.DeviceArray] = None,
|
||||
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||
logits_warper: Optional[FlaxLogitsProcessorList] = None,
|
||||
trace: bool = True,
|
||||
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||
):
|
||||
# init values
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
@ -381,7 +393,7 @@ class FlaxGenerationMixin:
|
||||
def sample_search_body_fn(state):
|
||||
"""state update fn."""
|
||||
prng_key, prng_key_next = jax.random.split(state.prng_key)
|
||||
model_outputs = model(state.current_token, **state.model_kwargs)
|
||||
model_outputs = model(state.current_token, params=params, **state.model_kwargs)
|
||||
|
||||
logits = model_outputs.logits[:, -1]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user