diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index d871c351016..628e34ffe11 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -249,83 +249,6 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer): return updated_node -def get_docstring_indent(docstring): - # Match the first line after the opening triple quotes - match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring) - if match: - # Return the indentation spaces captured - return len(match.group(1)) - return 0 - - -def is_full_docstring(original_docstring: str, new_docstring: str, original_level: int) -> bool: - """Check if `new_docstring` is a full docstring, or if it is only part of a docstring that should then - be merged with the existing old one. - """ - # libcst returns the docstrinbgs with literal `r"""` quotes in front - new_docstring = new_docstring.split('"""', 1)[1] - # The docstring contains Args definition, so it is self-contained - if re.search(r"\n\s*Args:\n", new_docstring): - return True - elif re.search(r"\n\s*Args:\n", original_docstring): - return False - # Check if the docstring contains args docstring (meaning it is self contained): - param_pattern = re.compile( - # |--- Group 1 ---|| Group 2 ||- Group 3 -||---------- Group 4 ----------| - rf"^\s{{0,{original_level}}}(\w+)\s*\(\s*([^, \)]*)(\s*.*?)\s*\)\s*:\s*((?:(?!\n^\s{{0,{original_level}}}\w+\s*\().)*)", - re.DOTALL | re.MULTILINE, - ) - match_object = param_pattern.search(new_docstring) - if match_object is not None: - return True - # If it contains Returns, but starts with text indented with an additional 4 spaces before, it is self-contained - # (this is the scenario when using `@add_start_docstrings_to_model_forward`, but adding more args to docstring) - match_object = re.search(r"\n([^\S\n]*)Returns:\n", new_docstring) - if match_object is not None: - full_indent = match_object.group(1) - striped_doc = new_docstring.strip("\n") - if striped_doc.startswith(full_indent + " " * 4) or striped_doc.startswith(full_indent + "\t"): - return True - return False - - -def merge_docstrings(original_docstring, updated_docstring): - original_level = get_docstring_indent(original_docstring) - if not is_full_docstring(original_docstring, updated_docstring, original_level): - # Split the docstring at the example section, assuming `"""` is used to define the docstring - parts = original_docstring.split("```") - if "```" in updated_docstring and len(parts) > 1: - updated_docstring = updated_docstring.lstrip('r"') - new_parts = updated_docstring.split("```") - if len(new_parts) != 3: - raise ValueError("There should only be one example, and it should have opening and closing '```'") - parts[1] = new_parts[1] - updated_docstring = "".join( - [ - f"\n{original_level * ' '}```", - parts[1], - "```", - parts[2], - ] - ) - docstring_opening, original_start_docstring = parts[0].rstrip(" \n").split('"""')[:2] - new_start_docstring = new_parts[0].rstrip(" \n") - docstring_opening += '"""' - if new_start_docstring.startswith(original_start_docstring): - updated_docstring = new_start_docstring + "\n" + updated_docstring - elif original_start_docstring.endswith(new_start_docstring): - updated_docstring = original_start_docstring + "\n" + updated_docstring - else: - updated_docstring = original_start_docstring + "\n" + new_start_docstring + "\n" + updated_docstring - updated_docstring = docstring_opening + updated_docstring - elif updated_docstring not in original_docstring: - # add tabulation if we are at the lowest level. - if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring): - updated_docstring = updated_docstring.replace("\n ", "\n ") - updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n') - return updated_docstring - - class SuperTransformer(cst.CSTTransformer): METADATA_DEPENDENCIES = (ParentNodeProvider,)