diff --git a/.gitmodules b/.gitmodules index 2fb2537..04dde04 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "third-party/json"] path = third-party/json url = https://github.com/nlohmann/json.git +[submodule "third-party/pcre2"] + path = third-party/pcre2 + url = https://github.com/PCRE2Project/pcre2.git diff --git a/CMakeLists.txt b/CMakeLists.txt index c5eac98..f0ce71c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,6 +29,19 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/abseil-cpp) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece) + +# Configure PCRE2 +set(PCRE2_BUILD_PCRE2_8 ON) +set(PCRE2_BUILD_PCRE2_16 OFF) +set(PCRE2_BUILD_PCRE2_32 OFF) +set(PCRE2_BUILD_TESTS OFF) +set(PCRE2_BUILD_PCRE2GREP OFF) +set(PCRE2_BUILD_PCRE2TEST OFF) +set(PCRE2_BUILD_PCRE2GPERF OFF) +set(PCRE2_BUILD_DOCS OFF) +set(PCRE2_BUILD_LIBPCRE2_PDB OFF) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/pcre2) + set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag}) file(GLOB tokenizers_source_files ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) @@ -45,9 +58,10 @@ target_include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece/src ${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2 ${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include - ${CMAKE_CURRENT_SOURCE_DIR}/third-party/llama.cpp-unicode/include) + ${CMAKE_CURRENT_SOURCE_DIR}/third-party/llama.cpp-unicode/include + ${CMAKE_CURRENT_SOURCE_DIR}/third-party/pcre2/src) -target_link_libraries(tokenizers PUBLIC sentencepiece-static re2::re2) +target_link_libraries(tokenizers PUBLIC sentencepiece-static re2::re2 pcre2-8) # Build test if(TOKENIZERS_BUILD_TEST) @@ -77,7 +91,8 @@ if(TOKENIZERS_BUILD_TEST) ${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece ${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2 - ${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include) + ${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include + ${CMAKE_CURRENT_SOURCE_DIR}/third-party/pcre2/src) target_link_libraries(${test_name} gtest_main GTest::gmock tokenizers) add_test(${test_name} "${test_name}") set_tests_properties(${test_name} PROPERTIES ENVIRONMENT ${test_env}) diff --git a/include/pytorch/tokenizers/pcre2_regex.h b/include/pytorch/tokenizers/pcre2_regex.h new file mode 100644 index 0000000..c3b3287 --- /dev/null +++ b/include/pytorch/tokenizers/pcre2_regex.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +// Define PCRE2 code unit width before including pcre2.h +#define PCRE2_CODE_UNIT_WIDTH 8 +#include + +#include + +namespace tokenizers { + +/** + * @brief PCRE2-based implementation of IRegex. + */ +class Pcre2Regex : public IRegex { + public: + /** + * @brief Construct a PCRE2 regex with the given pattern. + * + * @param pattern The regex pattern to compile. + */ + explicit Pcre2Regex(const std::string& pattern); + + /** + * @brief Destructor to clean up PCRE2 resources. + */ + ~Pcre2Regex(); + + /** + * @brief Return all non-overlapping matches found in the input string. + */ + virtual std::vector find_all(const std::string& text) const override; + + private: + pcre2_code* regex_; + pcre2_match_data* match_data_; + + friend Result> create_regex( + const std::string& pattern); +}; + +} // namespace tokenizers diff --git a/src/pcre2_regex.cpp b/src/pcre2_regex.cpp new file mode 100644 index 0000000..f680b2a --- /dev/null +++ b/src/pcre2_regex.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +namespace tokenizers { + +Pcre2Regex::Pcre2Regex(const std::string& pattern) + : regex_(nullptr), match_data_(nullptr) { + int error_code; + PCRE2_SIZE error_offset; + + // Compile the pattern + regex_ = pcre2_compile( + reinterpret_cast(pattern.c_str()), + pattern.length(), + PCRE2_UCP | PCRE2_UTF, // Enable Unicode support and UTF-8 mode + &error_code, + &error_offset, + nullptr); + + if (regex_ == nullptr) { + PCRE2_UCHAR error_buffer[256]; + pcre2_get_error_message(error_code, error_buffer, sizeof(error_buffer)); + std::cerr << "PCRE2 compilation failed at offset " << error_offset << ": " + << error_buffer << std::endl; + return; + } + + // Create match data + match_data_ = pcre2_match_data_create_from_pattern(regex_, nullptr); + if (match_data_ == nullptr) { + pcre2_code_free(regex_); + regex_ = nullptr; + std::cerr << "Failed to create PCRE2 match data" << std::endl; + return; + } +} + +Pcre2Regex::~Pcre2Regex() { + if (match_data_) { + pcre2_match_data_free(match_data_); + } + if (regex_) { + pcre2_code_free(regex_); + } +} + +std::vector Pcre2Regex::find_all(const std::string& text) const { + std::vector result; + + if (!regex_ || !match_data_) { + return result; + } + + PCRE2_SIZE* ovector; + PCRE2_SPTR subject = reinterpret_cast(text.c_str()); + PCRE2_SIZE subject_length = text.length(); + PCRE2_SIZE offset = 0; + + while (offset < subject_length) { + int rc = pcre2_match( + regex_, + subject, + subject_length, + offset, + 0, // Default options + match_data_, + nullptr); + + if (rc < 0) { + if (rc == PCRE2_ERROR_NOMATCH) { + break; // No more matches + } else { + // Error occurred + PCRE2_UCHAR error_buffer[256]; + pcre2_get_error_message(rc, error_buffer, sizeof(error_buffer)); + std::cerr << "PCRE2 matching error: " << error_buffer << std::endl; + break; + } + } + + ovector = pcre2_get_ovector_pointer(match_data_); + + // Add the match to the result + result.push_back({ovector[0], ovector[1]}); + + // Move to the next position after the match + offset = ovector[1]; + + // If the match was empty, move forward by one character to avoid infinite + // loop + if (ovector[0] == ovector[1]) { + offset++; + } + } + + return result; +} + +} // namespace tokenizers diff --git a/src/regex.cpp b/src/regex.cpp index 6b26b72..873b270 100644 --- a/src/regex.cpp +++ b/src/regex.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include #include @@ -18,8 +19,8 @@ namespace tokenizers { /** * @brief Factory function that creates a regex object using RE2 if possible. - * Falls back to std::regex if RE2 rejects the pattern with - * ErrorBadPerlOp. + * Falls back to PCRE2 if RE2 rejects the pattern, then to std::regex if + * PCRE2 fails. */ Result> create_regex(const std::string& pattern) { // Try RE2 first @@ -30,10 +31,20 @@ Result> create_regex(const std::string& pattern) { } if (re2->regex_->error_code() == re2::RE2::ErrorBadPerlOp) { - try { + // RE2 doesn't support some Perl features, try PCRE2 + auto pcre2 = std::make_unique("(" + pattern + ")"); + + if (pcre2->regex_ != nullptr && pcre2->match_data_ != nullptr) { std::cout << "RE2 is unable to support things such as negative lookaheads in " - << pattern << ", defaulting to std::regex."; + << pattern << ", using PCRE2 instead." << std::endl; + return static_cast>(std::move(pcre2)); + } + + // If PCRE2 also fails, fall back to std::regex + try { + std::cout + << "PCRE2 failed to compile pattern, falling back to std::regex."; auto std_regex = std::make_unique("(" + pattern + ")"); return static_cast>(std::move(std_regex)); } catch (const std::regex_error& e) { diff --git a/targets.bzl b/targets.bzl index 501b156..ae583d7 100644 --- a/targets.bzl +++ b/targets.bzl @@ -28,6 +28,9 @@ def define_common_targets(): srcs = ["src/regex.cpp"] + glob([ "src/*_regex.cpp", ]), + deps = [ + "fbsource//third-party/pcre2:pcre2-8", + ], exported_headers = subdir_glob([ ("include", "pytorch/tokenizers/regex.h"), ("include", "pytorch/tokenizers/*_regex.h"), diff --git a/test/test_regex.cpp b/test/test_regex.cpp new file mode 100644 index 0000000..fa396c1 --- /dev/null +++ b/test/test_regex.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "pytorch/tokenizers/pcre2_regex.h" +#include "pytorch/tokenizers/re2_regex.h" +#include "pytorch/tokenizers/regex.h" + +using namespace tokenizers; + +class RegexTest : public ::testing::Test {}; + +// Test basic functionality +TEST_F(RegexTest, BasicMatching) { + auto regex = TK_UNWRAP_THROW(create_regex("\\w+")); + + std::string text = "Hello world"; + auto matches = regex->find_all(text); + ASSERT_EQ(matches.size(), 2); + EXPECT_EQ(matches[0].start, 0); + EXPECT_EQ(matches[0].end, 5); + EXPECT_EQ( + text.substr(matches[0].start, matches[0].end - matches[0].start), + "Hello"); + EXPECT_EQ(matches[1].start, 6); + EXPECT_EQ(matches[1].end, 11); + EXPECT_EQ( + text.substr(matches[1].start, matches[1].end - matches[1].start), + "world"); +} + +// Test pattern that only PCRE2 supports (lookbehind) +TEST_F(RegexTest, Pcre2Specific) { + const std::string pattern = "(?<=@)\\w+"; + + // Verify that the factory function fallsback on a PCRE2 regex + auto regex = TK_UNWRAP_THROW(create_regex(pattern)); + EXPECT_NE(dynamic_cast(regex.get()), nullptr); + + std::string text = "user@example.com"; + auto matches = regex->find_all(text); + ASSERT_EQ(matches.size(), 1); + EXPECT_EQ(matches[0].start, 5); + EXPECT_EQ(matches[0].end, 12); + EXPECT_EQ( + text.substr(matches[0].start, matches[0].end - matches[0].start), + "example"); +} + +// Test complex pattern with negative lookahead that should fall back to PCRE2. +// This specific pattern is from the Qwen2.5 1.5B pretokenizer. +// https://huggingface.co/Qwen/Qwen2.5-1.5B/raw/main/tokenizer.json +TEST_F(RegexTest, ComplexPatternWithNegativeLookahead) { + const std::string complex_pattern = + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; + + // Now verify that the factory function fallsback on a PCRE2 regex + auto regex = TK_UNWRAP_THROW(create_regex(complex_pattern)); + EXPECT_NE(dynamic_cast(regex.get()), nullptr); + + // Test the pattern with some sample text + std::string text = "Hello's world\n test"; + auto matches = regex->find_all(text); + + // We expect to match: + // 1. "Hello" (word) + // 2. "'s" (contraction) + // 3. " world" (word with leading space) + // 4. "\n" (newline) + // 5. " " (whitespace) + // 6. " test" (word with leading space) + ASSERT_EQ(matches.size(), 6); + + EXPECT_EQ(matches[0].start, 0); + EXPECT_EQ(matches[0].end, 5); + EXPECT_EQ( + text.substr(matches[0].start, matches[0].end - matches[0].start), + "Hello"); + EXPECT_EQ(matches[1].start, 5); + EXPECT_EQ(matches[1].end, 7); + EXPECT_EQ( + text.substr(matches[1].start, matches[1].end - matches[1].start), "'s"); + EXPECT_EQ(matches[2].start, 7); + EXPECT_EQ(matches[2].end, 13); + EXPECT_EQ( + text.substr(matches[2].start, matches[2].end - matches[2].start), + " world"); + EXPECT_EQ(matches[3].start, 13); + EXPECT_EQ(matches[3].end, 14); + EXPECT_EQ( + text.substr(matches[3].start, matches[3].end - matches[3].start), "\n"); + EXPECT_EQ(matches[4].start, 14); + EXPECT_EQ(matches[4].end, 15); + EXPECT_EQ( + text.substr(matches[4].start, matches[4].end - matches[4].start), " "); + EXPECT_EQ(matches[5].start, 15); + EXPECT_EQ(matches[5].end, 20); + EXPECT_EQ( + text.substr(matches[5].start, matches[5].end - matches[5].start), + " test"); +} diff --git a/third-party/pcre2 b/third-party/pcre2 new file mode 160000 index 0000000..2e03e32 --- /dev/null +++ b/third-party/pcre2 @@ -0,0 +1 @@ +Subproject commit 2e03e323339ab692640626f02f8d8d6f95bff9c6