Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,9 @@ def __init__(self, vocab, ids_to_tokens, emoji):
r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
)
self.content_repatter6 = re.compile(
r"((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*億)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*万)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*千)*(0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*(千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+(\(税込\)|\(税抜\)|\+tax)*"
r"(?:\d,\d{3}|[\d億万千])*"
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
r"(?:\(税込\)|\(税抜\)|\+tax)*"
)
keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ def __init__(self, vocab, ids_to_tokens, emoji):
r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
)
self.content_repatter6 = re.compile(
r"((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*億)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*万)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*千)*(0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*(千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+(\(税込\)|\(税抜\)|\+tax)*"
r"(?:\d,\d{3}|[\d億万千])*"
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
r"(?:\(税込\)|\(税抜\)|\+tax)*"
)
keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
Expand Down
39 changes: 15 additions & 24 deletions src/transformers/models/nougat/tokenization_nougat_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,26 +113,17 @@ def normalize_list_like_lines(generation):
normalization adjusts the bullet point style and nesting levels based on the captured patterns.
"""

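# Matches list-like lines: a line that optionally starts with a "-" or "*" bullet
# (not immediately followed by another "-" or "*"), optionally followed by a number
# written in digits (\d) or roman numerals (one or more of i, x, v).
# Later on the same line there must be an inline " - " or " * " delimiter, which may
# be followed by an additional nested number such as "1.2 "; that optional numbering
# is captured as well. The pattern is then fed to re.finditer.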
pattern = r"(?:^)(-|\*)?(?!-|\*) ?((?:\d|[ixv])+ )?.+? (-|\*) (((?:\d|[ixv])+)\.(\d|[ixv]) )?.*(?:$)"

for match in reversed(list(re.finditer(pattern, generation, flags=re.I | re.M))):
start, stop = match.span()
delim = match.group(3) + " "
splits = match.group(0).split(delim)
lines = generation.split("\n")
output_lines = []
for line_no, line in enumerate(lines):
match = re.search(r". ([-*]) ", line)
if not match or line[0] not in ("-", "*"):
output_lines.append(line)
continue # Doesn't fit the pattern we want, no changes
delim = match.group(1) + " "
splits = line.split(delim)[1:]
replacement = ""

if match.group(1) is not None:
splits = splits[1:]
delim1 = match.group(1) + " "
else:
delim1 = ""
continue # Skip false positives

pre, post = generation[:start], generation[stop:]
delim1 = line[0] + " "

for i, item in enumerate(splits):
level = 0
Expand All @@ -144,15 +135,15 @@ def normalize_list_like_lines(generation):
level = potential_numeral.count(".")

replacement += (
("\n" if i > 0 else "") + ("\t" * level) + (delim if i > 0 or start == 0 else delim1) + item.strip()
("\n" if i > 0 else "") + ("\t" * level) + (delim if i > 0 or line_no == 0 else delim1) + item.strip()
)

if post == "":
post = "\n"
if line_no == len(lines) - 1: # If this is the last line in the generation
replacement += "\n" # Add an empty line to the end of the generation

generation = pre + replacement + post
output_lines.append(replacement)

return generation
return "\n".join(output_lines)


def find_next_punctuation(text: str, start_idx=0):
Expand Down