Skip to content

Improve BERT tokenization for accented characters and non-latin scripts #5740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 42 additions & 93 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@
#include <cstdio>
#include <cstring>
#include <ctime>
#include <cwctype>
#include <forward_list>
#include <fstream>
#include <functional>
#include <initializer_list>
#include <locale>
#include <map>
#include <memory>
#include <mutex>
Expand Down Expand Up @@ -8897,37 +8899,46 @@ struct llm_tokenizer_wpm {
}

std::vector<std::string> preprocess(const std::string & text) {
std::string ori_str = normalize(text);
uint64_t ori_size = ori_str.size();
// normalalization form D
std::vector<uint32_t> codepoints = codepoints_from_utf8(text);
std::vector<uint32_t> nfd_codepoints;
for (uint32_t code : codepoints) {
auto it = nfd_map.find(code);
if (it != nfd_map.end()) {
for (uint32_t c : it->second) {
nfd_codepoints.push_back(c);
}
} else {
nfd_codepoints.push_back(code);
}
}

// single punct / single symbol / single digit
// baseline: add whitespace on the left and right of punct and chinese characters
std::vector<std::string> words;
// strip accents, strip control, uniformize whitespace,
// to lowercase, pad chinese characters, pad punctuation
std::string new_str = "";
uint64_t i = 0;
while (i < ori_size) {
int utf_char_len = utf8_len(ori_str[i]);
if ((utf_char_len == 1) && ispunct(ori_str[i])) {
new_str += " ";
new_str += ori_str[i];
new_str += " ";
i += 1;
for (uint32_t code : nfd_codepoints) {
int type = codepoint_type(code);
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
continue;
}
else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) {
code = to_lower(code);
if (type == CODEPOINT_TYPE_WHITESPACE) {
code = ' ';
}
std::string s = codepoint_to_utf8(code);
if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
new_str += " ";
new_str += ori_str.substr(i, 3);
new_str += s;
new_str += " ";
i += 3;
}
else {
new_str += ori_str[i];
i += 1;
} else {
new_str += s;
}
}

// split by whitespace
uint64_t l = 0;
uint64_t r = 0;
std::vector<std::string> words;
while (r < new_str.size()) {
// if is whitespace
if (isspace(new_str[r])) {
Expand All @@ -8945,47 +8956,20 @@ struct llm_tokenizer_wpm {
return words;
}

std::string normalize(const std::string & text) {
// TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98
std::string text2 = strip_accents(text);
for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) {
char c = text2[i];
if (c >= 'A' && c <= 'Z') {
text2[i] = c - 'A' + 'a';
}
uint32_t to_lower(uint32_t code) {
#if defined(_WIN32)
if (code > 0xFFFF) {
return code;
}
return text2;
#endif
return std::tolower(wchar_t(code), std::locale("en_US.UTF-8"));
Copy link
Collaborator

@cebtenzzre cebtenzzre Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@iamlemec This depends on the system having the en_US.UTF-8 locale available, which is not always true in practice: nomic-ai/gpt4all#2160

Even if std::locale("") is available and not C, it could be e.g. tr_TR.UTF-8, which has different rules about how to uppercase/lowercase i and I.

Copy link
Collaborator

@cebtenzzre cebtenzzre Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it's not safe to assume that lowercasing one letter at a time is the same as lowercasing the whole string. See the notes at cppreference and the special-cased entry for capital sigma in UnicodeData.

}

bool is_chinese_char(const std::string & str) {
int len = str.length();
unsigned int codepoint = 0;
int num_bytes = 0;
int i = 0;
unsigned char ch = static_cast<unsigned char>(str[i]);
if (ch <= 0x7f) {
codepoint = ch;
num_bytes = 1;
} else if ((ch >> 5) == 0x06) {
codepoint = ch & 0x1f;
num_bytes = 2;
} else if ((ch >> 4) == 0x0e) {
codepoint = ch & 0x0f;
num_bytes = 3;
} else if ((ch >> 3) == 0x1e) {
codepoint = ch & 0x07;
num_bytes = 4;
}
for (int j = 1; j < num_bytes; ++j) {
if (i + j >= len) {
return false; // incomplete UTF-8 character
}
unsigned char next_ch = static_cast<unsigned char>(str[i + j]);
if ((next_ch >> 6) != 0x02) {
return false; // invalid trailing byte
}
codepoint = (codepoint << 6) | (next_ch & 0x3f);
}
bool is_ascii_punct(uint32_t code) {
return code < 256 && ispunct(code);
}

bool is_chinese_char(uint32_t codepoint) {
if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) ||
(codepoint >= 0x3400 && codepoint <= 0x4DBF) ||
(codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
Expand All @@ -9001,41 +8985,6 @@ struct llm_tokenizer_wpm {
return false;
}

std::string strip_accents(const std::string & input_string) {
std::string resultString;
std::map<std::string, char> accent_map = {
{"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'},
{"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'},
{"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'},
{"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'},
{"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'},
{"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'},
{"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'},
{"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'},
{"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'},
};

for (size_t i = 0; i < input_string.length();) {
int len = utf8_len(input_string[i]);
std::string curChar = input_string.substr(i, len);
auto iter = accent_map.find(curChar);
if (iter != accent_map.end()) {
resultString += iter->second;
} else {
resultString += curChar;
}
i += len;
}

return resultString;
}

static size_t utf8_len(char src) {
const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
return lookup[highbits];
}

const llama_vocab & vocab;
};

Expand Down
Loading