def _collect_errors(filename: str) -> int:
    """
    Count and report the incorrectly capitalized headings of one file.

    Every heading yielded by ``find_titles`` is compared with the result of
    ``correct_title_capitalization``; each mismatch is printed as
    ``<filename>:<line>:<err_msg> "<title>" to "<corrected>"``.

    Parameters
    ----------
    filename : str
        A file to validate, provided from the main method.

    Returns
    -------
    int
        Number of incorrect headings found.
    """
    errors: int = 0
    for title, line_number in find_titles(filename):
        # Compute the corrected form once; it is needed both for the
        # comparison and for the message.
        corrected = correct_title_capitalization(title)
        if title != corrected:
            print(f'{filename}:{line_number}:{err_msg} "{title}" to "{corrected}" ')
            errors += 1
    return errors


def main(source_paths: list[str]) -> int:
    """
    The main method to print all headings with incorrect capitalization.

    Parameters
    ----------
    source_paths : list of str
        List of directories or files to validate,
        provided through command line arguments.

    Returns
    -------
    int
        Total number of incorrect headings found across all given paths.
    """
    number_of_errors: int = 0

    for path in source_paths:
        if not os.path.isdir(path):
            # Validate only the current path. Looping over all of
            # `source_paths` here would re-validate every argument once per
            # file argument and would hand directories to `find_titles`.
            number_of_errors += _collect_errors(path)
            continue

        # `path` is a directory: walk it recursively and validate every
        # file found, skipping PNG images.
        for dirpath, _, filenames in os.walk(path):
            for filename in filenames:
                if not filename.endswith(".png"):
                    number_of_errors += _collect_errors(
                        os.path.join(dirpath, filename)
                    )

    return number_of_errors