Skip to content

Use internal result as regex return type #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/pytorch/tokenizers/pcre2_regex.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,6 @@ class Pcre2Regex : public IRegex {
pcre2_match_data* match_data_;
bool is_valid_;

friend std::unique_ptr<IRegex> createRegex(const std::string& pattern);
};
friend tokenizers::Result<std::unique_ptr<IRegex>> createRegex(
const std::string& pattern);
};
3 changes: 2 additions & 1 deletion include/pytorch/tokenizers/re2_regex.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ class Re2Regex : public IRegex {
private:
std::unique_ptr<re2::RE2> regex_;

friend std::unique_ptr<IRegex> createRegex(const std::string& pattern);
friend tokenizers::Result<std::unique_ptr<IRegex>> createRegex(
const std::string& pattern);
};
5 changes: 4 additions & 1 deletion include/pytorch/tokenizers/regex.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <string>
#include <vector>

#include <pytorch/tokenizers/result.h>

struct Match {
std::string text;
size_t position;
Expand Down Expand Up @@ -38,4 +40,5 @@ class IRegex {
* @param pattern The regex pattern to compile.
* @return A unique pointer to an IRegex-compatible object.
*/
std::unique_ptr<IRegex> createRegex(const std::string& pattern);
tokenizers::Result<std::unique_ptr<IRegex>> createRegex(
const std::string& pattern);
17 changes: 17 additions & 0 deletions include/pytorch/tokenizers/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,23 @@ T* Result<T>::operator->() {

} // namespace tokenizers

/**
* Unwraps a Result<T> value, throwing a runtime_error if the result contains an
* error.
*
* @param[in] result__ The Result<T> to unwrap
*/
#define TK_UNWRAP_THROW(result__) \
({ \
auto unwrap_result__ = (result__); \
if (!unwrap_result__.ok()) { \
throw std::runtime_error( \

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No ET code throws exceptions, AFAIK. Will we guard it from tokenizers now?

"Error: " + \
std::to_string(static_cast<int>(unwrap_result__.error()))); \
} \
std::move(unwrap_result__.get()); \
})

/**
* Unwrap a Result to obtain its value. If the Result contains an error,
* propagate the error via trivial function return.
Expand Down
2 changes: 1 addition & 1 deletion src/pre_tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) {
std::unique_ptr<IRegex> RegexPreTokenizer::create_regex_(
const std::string& pattern) {
assert(!pattern.empty());
return createRegex(pattern);
return TK_UNWRAP_THROW(createRegex(pattern));
}

std::vector<std::string> RegexPreTokenizer::pre_tokenize(
Expand Down
15 changes: 8 additions & 7 deletions src/regex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
* Falls back to PCRE2 if RE2 rejects the pattern, then to std::regex if
* PCRE2 fails.
*/
std::unique_ptr<IRegex> createRegex(const std::string& pattern) {
// Try RE2 first
tokenizers::Result<std::unique_ptr<IRegex>> createRegex(
const std::string& pattern) {
auto re2 = std::make_unique<Re2Regex>("(" + pattern + ")");

if (re2->ok()) {
return re2;
return static_cast<std::unique_ptr<IRegex>>(std::move(re2));
}

const re2::RE2* raw = re2->rawRegex();
Expand All @@ -29,21 +29,22 @@ std::unique_ptr<IRegex> createRegex(const std::string& pattern) {
std::cout
<< "RE2 is unable to support things such as negative lookaheads in "
<< pattern << ", using PCRE2 instead." << std::endl;
return pcre2;
return static_cast<std::unique_ptr<IRegex>>(std::move(pcre2));
}

// If PCRE2 also fails, fall back to std::regex
try {
std::cout
<< "PCRE2 failed to compile pattern, falling back to std::regex.";
return std::make_unique<StdRegex>("(" + pattern + ")");
auto std_regex = std::make_unique<StdRegex>("(" + pattern + ")");
return static_cast<std::unique_ptr<IRegex>>(std::move(std_regex));
} catch (const std::regex_error& e) {
std::cerr << "std::regex failed: " << e.what() << std::endl;
return nullptr;
return tokenizers::Error::LoadFailure;
}
} else {
std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
std::cerr << "Error: " << (raw ? raw->error() : "unknown") << std::endl;
return nullptr;
return tokenizers::Error::LoadFailure;
}
}
12 changes: 7 additions & 5 deletions src/tiktoken.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ using namespace detail;
// ------------------------------Util start------------------------------------
namespace {

static std::unique_ptr<IRegex> _create_regex(const std::string& pattern) {
static Result<std::unique_ptr<IRegex>> _create_regex(
const std::string& pattern) {
assert(!pattern.empty());
return createRegex(pattern);
}

static std::unique_ptr<IRegex> _build_special_token_regex(
static Result<std::unique_ptr<IRegex>> _build_special_token_regex(
const std::vector<std::pair<std::string, std::uint64_t>>& special_encoder) {
std::string special_pattern;
for (const auto& ele : special_encoder) {
Expand All @@ -56,7 +57,7 @@ static std::unique_ptr<IRegex> _build_special_token_regex(
special_pattern += re2::RE2::QuoteMeta(ele.first);
}
if (special_pattern.empty()) {
return nullptr;
return static_cast<std::unique_ptr<IRegex>>(nullptr);
}
return _create_regex(special_pattern);
}
Expand Down Expand Up @@ -152,8 +153,9 @@ Error Tiktoken::load(const std::string& path) {

special_token_map_.emplace(TokenMap(special_token_map));

_regex = _create_regex(_pattern);
special_token_regex_ = _build_special_token_regex(special_token_map);
_regex = TK_UNWRAP(_create_regex(_pattern));
special_token_regex_ =
TK_UNWRAP(_build_special_token_regex(special_token_map));

// initialize vocab_size, bos_tok, eos_tok
vocab_size_ = token_map_->size() + special_token_map_->size();
Expand Down
38 changes: 20 additions & 18 deletions test/test_regex.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#include <gtest/gtest.h>

#include "pytorch/tokenizers/regex.h"
#include "pytorch/tokenizers/re2_regex.h"
#include "pytorch/tokenizers/pcre2_regex.h"
#include "pytorch/tokenizers/re2_regex.h"
#include "pytorch/tokenizers/regex.h"
#include "pytorch/tokenizers/result.h"

// Test basic functionality
TEST(RegexTest, BasicMatching) {
auto regex = createRegex("\\w+");
auto regex = TK_UNWRAP_THROW(createRegex("\\w+"));
ASSERT_TRUE(regex->ok());

std::string text = "Hello world";
Expand All @@ -24,9 +25,9 @@ TEST(RegexTest, Pcre2Specific) {
const std::string pattern = "(?<=@)\\w+";
Re2Regex re2_regex(pattern);
ASSERT_FALSE(re2_regex.ok());

// Now verify that the factory function falls back to a PCRE2 regex
auto regex = createRegex(pattern);
auto regex = TK_UNWRAP_THROW(createRegex(pattern));
ASSERT_TRUE(regex->ok());

std::string text = "[email protected]";
Expand All @@ -40,20 +41,21 @@ TEST(RegexTest, Pcre2Specific) {
// This specific pattern is from the Qwen2.5 1.5B pretokenizer.
// https://huggingface.co/Qwen/Qwen2.5-1.5B/raw/main/tokenizer.json
TEST(RegexTest, ComplexPatternWithNegativeLookahead) {
const std::string complex_pattern = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";

const std::string complex_pattern =
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";

// First verify that RE2 cannot handle this pattern
Re2Regex re2_regex(complex_pattern);
ASSERT_FALSE(re2_regex.ok());

// Now verify that the factory function falls back to a PCRE2 regex
auto regex = createRegex(complex_pattern);
auto regex = TK_UNWRAP_THROW(createRegex(complex_pattern));
ASSERT_TRUE(regex->ok());

// Test the pattern with some sample text
std::string text = "Hello's world\n test";
auto matches = regex->findAll(text);

// We expect to match:
// 1. "Hello" (word)
// 2. "'s" (contraction)
Expand All @@ -62,22 +64,22 @@ TEST(RegexTest, ComplexPatternWithNegativeLookahead) {
// 5. " " (whitespace)
// 6. " test" (word with leading space)
ASSERT_EQ(matches.size(), 6);

EXPECT_EQ(matches[0].text, "Hello");
EXPECT_EQ(matches[0].position, 0);

EXPECT_EQ(matches[1].text, "'s");
EXPECT_EQ(matches[1].position, 5);

EXPECT_EQ(matches[2].text, " world");
EXPECT_EQ(matches[2].position, 7);

EXPECT_EQ(matches[3].text, "\n");
EXPECT_EQ(matches[3].position, 13);

EXPECT_EQ(matches[4].text, " ");
EXPECT_EQ(matches[4].position, 14);

EXPECT_EQ(matches[5].text, " test");
EXPECT_EQ(matches[5].position, 15);
}
}