diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 24bdf4faa06..5b9f38e1bc2 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -78,6 +78,9 @@ def convert_tf_weight_name_to_pt_weight_name( tf_name = tf_name[len(name_scope) :] tf_name = tf_name.lstrip("/") tf_name = tf_name.replace(":0", "") # device ids + if (len(tf_name) > 2048 and "___" in tf_name) or tf_name.count("___") > 10: + # ReDOS check + raise ValueError("TF variable name is too long or contains too many ___ separators: " + tf_name) tf_name = re.sub( r"/[^/]*___([^/]*)/", r"/\1/", tf_name ) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)