mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
unfreeze initial cache in gpt models (#14535)
This commit is contained in:
parent
2318bf77eb
commit
69511cdcae
@@ -444,7 +444,7 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
         init_variables = self.module.init(
             jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
         )
-        return init_variables["cache"]
+        return unfreeze(init_variables["cache"])

     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
     def __call__(
@@ -388,7 +388,7 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
         init_variables = self.module.init(
             jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
         )
-        return init_variables["cache"]
+        return unfreeze(init_variables["cache"])

     @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
     def __call__(
Loading…
Reference in New Issue
Block a user