Fix exploitable regexes in Nougat and GPTSan/GPTJNeoXJapanese (#36121)

* Fix potential regex catastrophic backtracking in NougatTokenizerFast

The original regex pattern in tokenization_nougat_fast.py was vulnerable to
catastrophic backtracking due to greedy quantifiers and nested alternations.
This commit replaces it with a more efficient pattern that:

1. Uses explicit character classes instead of dot (.)
2. Handles whitespace more precisely
3. Avoids unnecessary backtracking
4. Supports both lowercase and uppercase roman numerals
5. Maintains the same functionality while being more robust

* Try another regex

* Trying deepseek's answer

* Start with a simplification

* Another simplification

* Just rewrite the whole function myself

* Fix gptneox and gptsan

* Simplify the regex even further

* Tighten up the price regex a little

* Add possessive version of the regex

* Fix regex

* Much cleaner regexes

---------

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Matt 2025-02-21 19:49:51 +00:00 committed by GitHub
parent 547911e727
commit 92c5ca9dd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 51 additions and 30 deletions

View File

@ -18,6 +18,7 @@ import collections
import json
import os
import re
import sys
from typing import List, Optional, Tuple, Union
import numpy as np
@ -407,9 +408,23 @@ class SubWordJapaneseTokenizer:
self.content_repatter5 = re.compile(
r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
)
self.content_repatter6 = re.compile(
r"((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*億)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*万)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*千)*(0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*(千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+(\(税込\)|\(税抜\)|\+tax)*"
)
# The original version of this regex displays catastrophic backtracking behaviour. We avoid this using
# possessive quantifiers in Py >= 3.11. In versions below this, we avoid the vulnerability using a slightly
# different regex that should generally have the same behaviour in most non-pathological cases.
if sys.version_info >= (3, 11):
self.content_repatter6 = re.compile(
r"(?:\d,\d{3}|[\d億])*+"
r"(?:\d,\d{3}|[\d万])*+"
r"(?:\d,\d{3}|[\d千])*+"
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
r"(?:\(税込\)|\(税抜\)|\+tax)*"
)
else:
self.content_repatter6 = re.compile(
r"(?:\d,\d{3}|[\d億万千])*"
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
r"(?:\(税込\)|\(税抜\)|\+tax)*"
)
keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
self.content_trans1 = str.maketrans({k: "<BLOCK>" for k in keisen + blocks})

View File

@ -18,6 +18,7 @@ import collections
import json
import os
import re
import sys
from typing import Optional, Tuple
import numpy as np
@ -230,9 +231,23 @@ class SubWordJapaneseTokenizer:
self.content_repatter5 = re.compile(
r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
)
self.content_repatter6 = re.compile(
r"((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*億)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*万)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*千)*(0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*(千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+(\(税込\)|\(税抜\)|\+tax)*"
)
# The original version of this regex displays catastrophic backtracking behaviour. We avoid this using
# possessive quantifiers in Py >= 3.11. In versions below this, we avoid the vulnerability using a slightly
# different regex that should generally have the same behaviour in most non-pathological cases.
if sys.version_info >= (3, 11):
self.content_repatter6 = re.compile(
r"(?:\d,\d{3}|[\d億])*+"
r"(?:\d,\d{3}|[\d万])*+"
r"(?:\d,\d{3}|[\d千])*+"
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
r"(?:\(税込\)|\(税抜\)|\+tax)*"
)
else:
self.content_repatter6 = re.compile(
r"(?:\d,\d{3}|[\d億万千])*"
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
r"(?:\(税込\)|\(税抜\)|\+tax)*"
)
keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
self.content_trans1 = str.maketrans({k: "<BLOCK>" for k in keisen + blocks})

View File

@ -113,26 +113,17 @@ def normalize_list_like_lines(generation):
normalization adjusts the bullet point style and nesting levels based on the captured patterns.
"""
# This matches lines starting with - or *, not followed by - or * (lists)
# that are then numbered by digits \d or roman numerals (one or more)
# and then, optional additional numbering of this line is captured
# this is then fed to re.finditer.
pattern = r"(?:^)(-|\*)?(?!-|\*) ?((?:\d|[ixv])+ )?.+? (-|\*) (((?:\d|[ixv])+)\.(\d|[ixv]) )?.*(?:$)"
for match in reversed(list(re.finditer(pattern, generation, flags=re.I | re.M))):
start, stop = match.span()
delim = match.group(3) + " "
splits = match.group(0).split(delim)
lines = generation.split("\n")
output_lines = []
for line_no, line in enumerate(lines):
match = re.search(r". ([-*]) ", line)
if not match or line[0] not in ("-", "*"):
output_lines.append(line)
continue # Doesn't fit the pattern we want, no changes
delim = match.group(1) + " "
splits = line.split(delim)[1:]
replacement = ""
if match.group(1) is not None:
splits = splits[1:]
delim1 = match.group(1) + " "
else:
delim1 = ""
continue # Skip false positives
pre, post = generation[:start], generation[stop:]
delim1 = line[0] + " "
for i, item in enumerate(splits):
level = 0
@ -144,15 +135,15 @@ def normalize_list_like_lines(generation):
level = potential_numeral.count(".")
replacement += (
("\n" if i > 0 else "") + ("\t" * level) + (delim if i > 0 or start == 0 else delim1) + item.strip()
("\n" if i > 0 else "") + ("\t" * level) + (delim if i > 0 or line_no == 0 else delim1) + item.strip()
)
if post == "":
post = "\n"
if line_no == len(lines) - 1: # If this is the last line in the generation
replacement += "\n" # Add an empty line to the end of the generation
generation = pre + replacement + post
output_lines.append(replacement)
return generation
return "\n".join(output_lines)
def find_next_punctuation(text: str, start_idx=0):