[TF] from_pt should respect authorized_unexpected_keys (#8056)

This commit is contained in:
Sam Shleifer 2020-10-26 13:53:27 -04:00 committed by GitHub
parent 7ff7c4934b
commit bc9332b545
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -208,6 +208,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
if tf_model.authorized_missing_keys is not None:
for pat in tf_model.authorized_missing_keys:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
if tf_model.authorized_unexpected_keys is not None:
for pat in tf_model.authorized_unexpected_keys:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(