unfreeze initial cache in gpt models (#14535)

This commit is contained in:
Suraj Patil 2021-11-26 18:21:47 +05:30 committed by GitHub
parent 2318bf77eb
commit 69511cdcae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@@ -444,7 +444,7 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
- return init_variables["cache"]
+ return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
def __call__(

View File

@@ -388,7 +388,7 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
- return init_variables["cache"]
+ return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
def __call__(