correctly handle mt5 (#9879)

This commit is contained in:
Stas Bekman 2021-01-29 08:11:22 -08:00 committed by GitHub
parent 7eadfe166e
commit 6bf94bc0b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -563,7 +563,7 @@ def freeze_embeds(model):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
model_type = model.config.model_type
if model_type == "t5":
if model_type in ["t5", "mt5"]:
freeze_params(model.shared)
for d in [model.encoder, model.decoder]:
freeze_params(d.embed_tokens)