diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 6d9205c..fa155e2 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -14,7 +14,7 @@ concurrency: jobs: unittest-linux: name: unittest-linux - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main strategy: fail-fast: false with: @@ -26,4 +26,11 @@ jobs: set -ex cmake -DCMAKE_BUILD_TYPE=Debug test -Bbuild/test cmake --build build/test -j9 --config Debug - cd build/test && ctest + pushd build/test && ctest && popd + + # Install tokenizers + pip install . -v + pip install pytest blobfile + + # Run tests + pytest diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 2e77029..a466221 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -33,4 +33,11 @@ jobs: set -ex cmake -DCMAKE_BUILD_TYPE=Debug test -Bbuild/test cmake --build build/test -j9 --config Debug - cd build/test && ctest + pushd build/test && ctest && popd + + # Install tokenizers + ${CONDA_RUN} pip install . -v + ${CONDA_RUN} pip install pytest blobfile + + # Run tests + ${CONDA_RUN} pytest diff --git a/CMakeLists.txt b/CMakeLists.txt index 629f995..501305e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,11 +13,12 @@ # cmake_minimum_required(VERSION 3.18) set(CMAKE_CXX_STANDARD 17) - +set(CMAKE_POLICY_VERSION_MINIMUM 3.5) project(Tokenizers) option(TOKENIZERS_BUILD_TEST "Build tests" OFF) option(TOKENIZERS_BUILD_TOOLS "Build tools" OFF) +option(TOKENIZERS_BUILD_PYTHON "Build Python bindings" OFF) option(SUPPORT_REGEX_LOOKAHEAD "Support regex lookahead patterns (requires PCRE2)" OFF ) @@ -122,17 +123,49 @@ if(TOKENIZERS_BUILD_TOOLS) add_subdirectory(examples/tokenize_tool) endif() +# Build Python bindings +if(TOKENIZERS_BUILD_PYTHON) + include(FetchContent) + FetchContent_Declare( + pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.13.6 + ) + FetchContent_MakeAvailable(pybind11) + + # Create the Python extension module + pybind11_add_module(pytorch_tokenizers_cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/python_bindings.cpp + ) + + # Link with the tokenizers library + target_link_libraries(pytorch_tokenizers_cpp PRIVATE tokenizers) + + # Set properties for the Python extension + target_compile_definitions(pytorch_tokenizers_cpp PRIVATE VERSION_INFO=${PROJECT_VERSION}) + + # Set the output name and let setuptools control the output directory + set_target_properties(pytorch_tokenizers_cpp PROPERTIES + OUTPUT_NAME "pytorch_tokenizers_cpp" + ) + + # Don't install the Python extension here - let setuptools handle it + # The setup.py will copy the built extension to the appropriate location +endif() + # Installation rules include(GNUInstallDirs) -# Install the library and its dependencies -install( - TARGETS tokenizers re2 sentencepiece-static - EXPORT tokenizers-targets - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} -) +if(NOT TOKENIZERS_BUILD_PYTHON) + # Install the library and its dependencies + install( + TARGETS tokenizers re2 sentencepiece-static + EXPORT tokenizers-targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ) +endif() # Install header files install( diff --git a/include/pytorch/tokenizers/tiktoken.h b/include/pytorch/tokenizers/tiktoken.h index 15c0765..6cf9edc 100644 --- 
a/include/pytorch/tokenizers/tiktoken.h
+++ b/include/pytorch/tokenizers/tiktoken.h
@@ -46,6 +46,27 @@ class Tiktoken : public detail::BPETokenizerBase {
     }
   }
 
+  explicit Tiktoken(
+      std::string pattern,
+      const std::vector<std::string>& special_tokens,
+      size_t bos_token_index,
+      size_t eos_token_index)
+      : Tiktoken(
+            pattern,
+            std::make_unique<std::vector<std::string>>(special_tokens),
+            bos_token_index,
+            eos_token_index) {}
+
+  explicit Tiktoken(
+      const std::vector<std::string>& special_tokens,
+      size_t bos_token_index,
+      size_t eos_token_index)
+      : Tiktoken(
+            _get_default_patern(),
+            std::make_unique<std::vector<std::string>>(special_tokens),
+            bos_token_index,
+            eos_token_index) {}
+
   explicit Tiktoken(
       std::unique_ptr<std::vector<std::string>> special_tokens,
       size_t bos_token_index,
diff --git a/pyproject.toml b/pyproject.toml
index f09ba99..8176424 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,6 +4,8 @@ requires = [
     "pip>=23",  # For building the pip package.
     "setuptools>=63",  # For building the pip package contents.
     "wheel",  # For building the pip package archive.
+    "pytest",  # For running tests.
+    "pybind11",  # For building the pybind11 C++ extension.
 ]
 build-backend = "setuptools.build_meta"
@@ -64,12 +66,22 @@ Changelog = "https://github.com/pytorch/executorch/releases"
 [tool.setuptools.exclude-package-data]
 "*" = ["*.pyc"]
 
-[tool.usort]
-# Do not try to put "first-party" imports in their own section.
-first_party_detection = false
+[tool.pytest.ini_options]
+testpaths = ["test"]
+python_files = ["test_*.py", "*_test.py"]
+python_classes = ["Test*"]
+python_functions = ["test_*"]
 
 [tool.black]
-# Emit syntax compatible with older versions of python instead of only the range
-# specified by `requires-python`. TODO: Remove this once we support these older
-# versions of python and can expand the `requires-python` range.
-target-version = ["py38", "py39", "py310", "py311", "py312"]
+target-version = ['py38', 'py39', 'py310', 'py311', 'py312']
+include = '\.pyi?$'
+extend-exclude = '''
+/(
+    # directories
+    \.eggs
+    | \.git
+    | build
+    | dist
+    | third-party
+)/
+'''
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..6280e51
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,54 @@
+[pytest]
+# Pytest configuration for PyTorch Tokenizers
+
+# Test discovery
+testpaths = test
+python_files = test_*.py *_test.py
+python_classes = Test*
+python_functions = test_*
+
+# Output options with explicit ignores
+addopts =
+    # show summary of all tests that did not pass
+    -rEfX
+    # Make tracebacks shorter
+    --tb=native
+    # capture only Python print and C++ py::print, but not C output (low-level Python errors)
+    --capture=sys
+    # don't suppress warnings, but don't shove them all to the end either
+    -p no:warnings
+    # Explicitly ignore non-test directories so collection stays fast and
+    # does not pick up vendored or generated code.
+ --ignore=third-party + --ignore=build + --ignore=cmake + --ignore=examples + --ignore=pytorch_tokenizers.egg-info + +# Directories to ignore during test collection +norecursedirs = + build* + third-party* + cmake* + examples* + .git* + __pycache__* + *.egg-info* + *third-party* + +# Markers +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + integration: marks tests as integration tests + unit: marks tests as unit tests + +# Minimum version +minversion = 6.0 + +# Test timeout (in seconds) +timeout = 300 + +# Filter warnings +filterwarnings = + ignore::DeprecationWarning + ignore::PendingDeprecationWarning diff --git a/pytorch_tokenizers/TARGETS b/pytorch_tokenizers/TARGETS index d563279..36fc9ff 100644 --- a/pytorch_tokenizers/TARGETS +++ b/pytorch_tokenizers/TARGETS @@ -4,7 +4,7 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library") load(":targets.bzl", "define_common_targets") -oncall("executorch") +oncall("ai_infra_mobile_platform") define_common_targets() diff --git a/pytorch_tokenizers/__init__.py b/pytorch_tokenizers/__init__.py index 441117a..d6f78c0 100644 --- a/pytorch_tokenizers/__init__.py +++ b/pytorch_tokenizers/__init__.py @@ -3,8 +3,14 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# @lint-ignore-every LICENSELINT +""" +PyTorch Tokenizers - Fast tokenizers for PyTorch + +This package provides Python bindings for fast C++ tokenizer implementations +including HuggingFace, TikToken, Llama2C, and SentencePiece tokenizers. +""" +# @lint-ignore-every LICENSELINT from typing import Optional @@ -12,7 +18,23 @@ from .llama2c import Llama2cTokenizer from .tiktoken import TiktokenTokenizer -__all__ = ["TiktokenTokenizer", "Llama2cTokenizer", "HuggingFaceTokenizer"] +__version__ = "0.1.0" + +try: + from .pytorch_tokenizers_cpp import ( # @manual=//pytorch/tokenizers:pytorch_tokenizers_cpp + Error, + HFTokenizer as CppHFTokenizer, + Llama2cTokenizer as CppLlama2cTokenizer, + SPTokenizer as CppSPTokenizer, + Tiktoken as CppTiktoken, + TokenIndex, + Tokenizer, + ) +except ImportError as e: + raise ImportError( + f"Failed to import C++ tokenizer bindings: {e}. " + "Make sure the package was built correctly with pybind11." 
+ ) from e def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None): @@ -25,3 +47,17 @@ def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = No print("Using Tiktokenizer") tokenizer = TiktokenTokenizer(model_path=str(tokenizer_path)) return tokenizer + + +__all__ = [ + "CppHFTokenizer", + "CppLlama2cTokenizer", + "CppSPTokenizer", + "CppTiktoken", + "Error", + "HFTokenizer", + "Llama2cTokenizer", + "TiktokenTokenizer", + "TokenIndex", + "Tokenizer", +] diff --git a/pytorch_tokenizers/targets.bzl b/pytorch_tokenizers/targets.bzl index e68ea32..241e687 100644 --- a/pytorch_tokenizers/targets.bzl +++ b/pytorch_tokenizers/targets.bzl @@ -11,19 +11,17 @@ def define_common_targets(): srcs = [ "__init__.py", "constants.py", + "hf_tokenizer.py", "llama2c.py", "tiktoken.py", - "hf_tokenizer.py", ], base_module = "pytorch_tokenizers", visibility = ["PUBLIC"], _is_external_target = True, - external_deps = [ - "sentencepiece-py", - ], deps = [ - "fbsource//third-party/pypi/blobfile:blobfile", + "fbsource//third-party/pypi/sentencepiece:sentencepiece", "fbsource//third-party/pypi/tiktoken:tiktoken", "fbsource//third-party/pypi/tokenizers:tokenizers", + "//pytorch/tokenizers:pytorch_tokenizers_cpp", # @manual ], ) diff --git a/pytorch_tokenizers/tiktoken.py b/pytorch_tokenizers/tiktoken.py index 64e9264..2163213 100644 --- a/pytorch_tokenizers/tiktoken.py +++ b/pytorch_tokenizers/tiktoken.py @@ -12,7 +12,6 @@ AbstractSet, cast, Collection, - Dict, Iterator, List, Literal, diff --git a/setup.py b/setup.py index cfa4b2d..406fe9e 100644 --- a/setup.py +++ b/setup.py @@ -5,14 +5,156 @@ # LICENSE file in the root directory of this source tree. # @lint-ignore-every LICENSELINT # type: ignore[syntax] -from setuptools import find_packages, setup +import os +import re +import subprocess +import sys +from pathlib import Path + +from setuptools import Extension, find_packages, setup +from setuptools.command.build_ext import build_ext + +# Read the README file with open("README.md", "r") as f: long_description = f.read() + +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class CMakeBuild(build_ext): + def build_extension(self, ext): # noqa C901 + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + + # Ensure the extension goes into the pytorch_tokenizers package directory + extdir = os.path.join(extdir, "pytorch_tokenizers") + + # Required for auto-detection & inclusion of auxiliary "native" libs + if not extdir.endswith(os.path.sep): + extdir += os.path.sep + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + # CMake lets you override the generator - we check this. + # Can be set with Conda-Build, for example. + cmake_generator = os.environ.get("CMAKE_GENERATOR", "") + + # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + cmake_args = [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm + "-DSUPPORT_REGEX_LOOKAHEAD=ON", + "-DTOKENIZERS_BUILD_PYTHON=ON", + "-DCMAKE_POSITION_INDEPENDENT_CODE=ON", + ] + build_args = ["--target", "pytorch_tokenizers_cpp"] + + # Adding CMake arguments set as environment variable + # (needed e.g. 
to build for ARM OSX on conda-forge) + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + if self.compiler.compiler_type != "msvc": + # Using Ninja-build since it a) is available as a wheel and b) + # multithreads automatically. MSVC would require all variables be + # exported for Ninja to pick it up, which is a little tricky to do. + # Users can override the generator with CMAKE_GENERATOR in CMake + # 3.15+. + if not cmake_generator or cmake_generator == "Ninja": + try: + import ninja # noqa: F401 + + ninja_executable_path = os.path.join(ninja.BIN_DIR, "ninja") + cmake_args += [ + "-GNinja", + f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", + ] + except ImportError: + pass + + else: + # Single config generators are handled "normally" + single_config = any(x in cmake_generator for x in {"NMake", "Ninja"}) + + # CMake allows an arch-in-generator style for backward compatibility + contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"}) + + # Specify the arch if using MSVC generator, but only if it doesn't + # contain a backward-compatibility arch spec already in the + # generator name. + if not single_config and not contains_arch: + cmake_args += ["-A", "x64"] + + # Multi-config generators have a different way to specify configs + if not single_config: + cmake_args += [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}" + ] + build_args += ["--config", cfg] + + if sys.platform.startswith("darwin"): + # Cross-compile support for macOS - respect ARCHFLAGS if set + archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) + if archs: + cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] + + # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level + # across all generators. + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + # self.parallel is a Python 3 only way to set parallel jobs by hand + # using -j in the build_ext call, not supported by pip or PyPA-build. + if hasattr(self, "parallel") and self.parallel: + # CMake 3.12+ only. 
+                build_args += [f"-j{self.parallel}"]
+
+        build_temp = Path(self.build_temp) / ext.name
+        if not build_temp.exists():
+            build_temp.mkdir(parents=True)
+
+        subprocess.run(
+            ["cmake", ext.sourcedir] + cmake_args, cwd=build_temp, check=True
+        )
+        subprocess.run(
+            ["cmake", "--build", "."] + build_args, cwd=build_temp, check=True
+        )
+
+
 setup(
+    name="pytorch-tokenizers",
     version="0.1.0",
     long_description=long_description,
     long_description_content_type="text/markdown",
+    url="https://github.com/pytorch-labs/tokenizers",
     packages=find_packages(),
+    ext_modules=[CMakeExtension("pytorch_tokenizers_cpp")],
+    cmdclass={"build_ext": CMakeBuild},
+    zip_safe=False,
+    python_requires=">=3.10",
+    install_requires=[
+        "pybind11>=2.6.0",
+    ],
+    setup_requires=[
+        "pybind11>=2.6.0",
+        "cmake>=3.18",
+    ],
+    classifiers=[
+        "Development Status :: 3 - Alpha",
+        "Intended Audience :: Developers",
+        "License :: OSI Approved :: BSD License",
+        "Operating System :: OS Independent",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
+        "Programming Language :: C++",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    ],
 )
diff --git a/src/python_bindings.cpp b/src/python_bindings.cpp
new file mode 100644
index 0000000..af57d3c
--- /dev/null
+++ b/src/python_bindings.cpp
@@ -0,0 +1,256 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// @lint-ignore-every LICENSELINT
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <string>
+
+#include <pytorch/tokenizers/error.h>
+#include <pytorch/tokenizers/hf_tokenizer.h>
+#include <pytorch/tokenizers/llama2c_tokenizer.h>
+#include <pytorch/tokenizers/result.h>
+#include <pytorch/tokenizers/sentencepiece.h>
+#include <pytorch/tokenizers/tiktoken.h>
+#include <pytorch/tokenizers/tokenizer.h>
+
+namespace py = pybind11;
+using namespace tokenizers;
+
+// Helper function to convert Result<T> to Python
+template <typename T>
+T unwrap_result(const Result<T>& result) {
+  if (result.error() != Error::Ok) {
+    std::string error_msg;
+    switch (result.error()) {
+      case Error::Ok:
+        error_msg = "Ok";
+        break;
+      case Error::Internal:
+        error_msg = "Internal";
+        break;
+      case Error::Uninitialized:
+        error_msg = "Uninitialized";
+        break;
+      case Error::OutOfRange:
+        error_msg = "OutOfRange";
+        break;
+      case Error::LoadFailure:
+        error_msg = "LoadFailure";
+        break;
+      case Error::EncodeFailure:
+        error_msg = "EncodeFailure";
+        break;
+      case Error::Base64DecodeFailure:
+        error_msg = "Base64DecodeFailure";
+        break;
+      case Error::ParseFailure:
+        error_msg = "ParseFailure";
+        break;
+      case Error::DecodeFailure:
+        error_msg = "DecodeFailure";
+        break;
+      case Error::RegexFailure:
+        error_msg = "RegexFailure";
+        break;
+      default:
+        error_msg = "Unknown error";
+        break;
+    }
+    throw std::runtime_error("Tokenizer error: " + error_msg);
+  }
+  return result.get();
+}
+
+PYBIND11_MODULE(pytorch_tokenizers_cpp, m) {
+  m.doc() = "PyTorch Tokenizers Python bindings";
+
+  // Bind Error enum
+  py::enum_<Error>(m, "Error")
+      .value("Ok", Error::Ok)
+      .value("Internal", Error::Internal)
+      .value("Uninitialized", Error::Uninitialized)
+      .value("OutOfRange", Error::OutOfRange)
+      .value("LoadFailure", Error::LoadFailure)
+      .value("EncodeFailure", Error::EncodeFailure)
+      .value("Base64DecodeFailure", Error::Base64DecodeFailure)
+      .value("ParseFailure", Error::ParseFailure)
+      .value("DecodeFailure", Error::DecodeFailure)
+      .value("RegexFailure", Error::RegexFailure);
+
+  // Bind TokenIndex struct
+  py::class_<TokenIndex>(m, "TokenIndex")
+      .def_readonly("str", &TokenIndex::str)
+      .def_readonly("id", &TokenIndex::id);
+
+  // Bind base Tokenizer class
+  py::class_<Tokenizer>(m, "Tokenizer")
+      .def(
+          "load",
+          [](Tokenizer& self, const std::string& tokenizer_path) {
+            Error error = self.load(tokenizer_path);
+            if (error != Error::Ok) {
+              throw std::runtime_error("Failed to load tokenizer");
+            }
+          },
+          py::arg("tokenizer_path"))
+      .def(
+          "encode",
+          [](const Tokenizer& self,
+             const std::string& input,
+             int8_t bos,
+             int8_t eos) {
+            return unwrap_result(self.encode(input, bos, eos));
+          },
+          py::arg("input"),
+          py::arg("bos") = 0,
+          py::arg("eos") = 0)
+      .def(
+          "decode",
+          [](const Tokenizer& self, uint64_t token) {
+            return unwrap_result(self.decode(token, token));
+          },
+          py::arg("token"))
+      .def("vocab_size", &Tokenizer::vocab_size)
+      .def("bos_tok", &Tokenizer::bos_tok)
+      .def("eos_tok", &Tokenizer::eos_tok)
+      .def("is_loaded", &Tokenizer::is_loaded);
+
+  // Bind HFTokenizer
+  py::class_<HFTokenizer, Tokenizer>(m, "HFTokenizer")
+      .def(py::init<>())
+      .def(
+          "load",
+          [](HFTokenizer& self, const std::string& tokenizer_path) {
+            Error error = self.load(tokenizer_path);
+            if (error != Error::Ok) {
+              throw std::runtime_error("Failed to load HF tokenizer");
+            }
+          },
+          py::arg("tokenizer_path"))
+      .def(
+          "encode",
+          [](const HFTokenizer& self,
+             const std::string& input,
+             int8_t bos,
+             int8_t eos) {
+            return unwrap_result(self.encode(input, bos, eos));
+          },
+          py::arg("input"),
+          py::arg("bos") = 0,
+          py::arg("eos") = 0)
+      .def(
+          "decode",
+          [](const HFTokenizer& self, uint64_t token) {
+            return unwrap_result(self.decode(token, token));
+          },
+          py::arg("token"));
+
+  // Bind Tiktoken
+  py::class_<Tiktoken, Tokenizer>(m, "Tiktoken")
+      .def(py::init<>())
+      .def(
+          py::init<std::vector<std::string>, size_t, size_t>(),
+          py::arg("special_tokens"),
+          py::arg("bos_token_index"),
+          py::arg("eos_token_index"))
+      .def(
+          py::init<std::string, std::vector<std::string>, size_t, size_t>(),
+          py::arg("pattern"),
+          py::arg("special_tokens"),
+          py::arg("bos_token_index"),
+          py::arg("eos_token_index"))
+      .def(
+          "load",
+          [](Tiktoken& self, const std::string& tokenizer_path) {
+            Error error = self.load(tokenizer_path);
+            if (error != Error::Ok) {
+              throw std::runtime_error("Failed to load Tiktoken tokenizer");
+            }
+          },
+          py::arg("tokenizer_path"))
+      .def(
+          "encode",
+          [](const Tiktoken& self,
+             const std::string& input,
+             int8_t bos,
+             int8_t eos) {
+            return unwrap_result(self.encode(input, bos, eos));
+          },
+          py::arg("input"),
+          py::arg("bos") = 0,
+          py::arg("eos") = 0)
+      .def(
+          "decode",
+          [](const Tiktoken& self, uint64_t token) {
+            return unwrap_result(self.decode(token, token));
+          },
+          py::arg("token"));
+
+  // Bind Llama2cTokenizer
+  py::class_<Llama2cTokenizer, Tokenizer>(m, "Llama2cTokenizer")
+      .def(py::init<>())
+      .def(
+          "load",
+          [](Llama2cTokenizer& self, const std::string& tokenizer_path) {
+            Error error = self.load(tokenizer_path);
+            if (error != Error::Ok) {
+              throw std::runtime_error("Failed to load Llama2c tokenizer");
+            }
+          },
+          py::arg("tokenizer_path"))
+      .def(
+          "encode",
+          [](const Llama2cTokenizer& self,
+             const std::string& input,
+             int8_t bos,
+             int8_t eos) {
+            return unwrap_result(self.encode(input, bos, eos));
+          },
+          py::arg("input"),
+          py::arg("bos") = 0,
+          py::arg("eos") = 0)
+      .def(
+          "decode",
+          [](const Llama2cTokenizer& self, uint64_t token) {
+            return unwrap_result(self.decode(token, token));
+          },
+          py::arg("token"));
+
+  // Bind SPTokenizer (SentencePiece)
+  py::class_<SPTokenizer, Tokenizer>(m, "SPTokenizer")
+      .def(py::init<>())
+      .def(
+          "load",
+          [](SPTokenizer& self, const std::string& tokenizer_path) {
+            Error error = self.load(tokenizer_path);
+            if (error != Error::Ok) {
+              throw std::runtime_error(
+                  "Failed to load SentencePiece tokenizer");
+            }
+          },
+          py::arg("tokenizer_path"))
+      .def(
+          "encode",
+          [](const SPTokenizer& self,
+             const std::string& input,
+             int8_t bos,
+             int8_t eos) {
+            return unwrap_result(self.encode(input, bos, eos));
+          },
+          py::arg("input"),
+          py::arg("bos") = 0,
+          py::arg("eos") = 0)
+      .def(
+          "decode",
+          [](const SPTokenizer& self, uint64_t token) {
+            return unwrap_result(self.decode(token, token));
+          },
+          py::arg("token"));
+}
diff --git a/targets.bzl b/targets.bzl
index 1f3e963..6b0547f 100644
--- a/targets.bzl
+++ b/targets.bzl
@@ -171,3 +171,24 @@ def define_common_targets():
         ],
         platforms = PLATFORMS,
     )
+
+    runtime.cxx_python_extension(
+        name = "pytorch_tokenizers_cpp",
+        srcs = [
+            "src/python_bindings.cpp",
+        ],
+        visibility = [
+            "@EXECUTORCH_CLIENTS",
+            "//pytorch/tokenizers/...",
+        ],
+        base_module = "pytorch_tokenizers",
+        deps = [
+            ":hf_tokenizer",
+            ":llama2c_tokenizer",
+            ":sentencepiece",
+            ":tiktoken",
+        ],
+        external_deps = [
+            "pybind11",
+        ],
+    )
diff --git a/test/TARGETS b/test/TARGETS
index 2341af9..6e2be8a 100644
--- a/test/TARGETS
+++ b/test/TARGETS
@@ -1,8 +1,34 @@
 # Any targets that should be shared between fbcode and xplat must be defined in
 # targets.bzl. This file can contain fbcode-only targets.
 
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load(":targets.bzl", "define_common_targets")
 
-oncall("executorch")
+oncall("ai_infra_mobile_platform")
 
 define_common_targets()
+
+runtime.export_file(
+    name = "test_hf_tokenizer.json",
+    src = "resources/test_hf_tokenizer.json",
+    visibility = [
+        "//pytorch/tokenizers/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+)
+
+runtime.python_test(
+    name = "test_python_bindings",
+    srcs = [
+        "test_python_bindings.py",
+    ],
+    preload_deps = [
+        "//pytorch/tokenizers:regex_lookahead",
+    ],
+    resources = {
+        ":test_hf_tokenizer.json": "resources/test_hf_tokenizer.json",
+    },
+    deps = [
+        "//pytorch/tokenizers/pytorch_tokenizers:tokenizers",
+    ],
+)
diff --git a/test/targets.bzl b/test/targets.bzl
index cc79100..aa7c479 100644
--- a/test/targets.bzl
+++ b/test/targets.bzl
@@ -127,9 +127,10 @@ def define_common_targets():
         ],
         deps = [
             "//pytorch/tokenizers/pytorch_tokenizers:tokenizers",
+            "fbsource//third-party/pypi/blobfile:blobfile",
         ],
         resources = {
-            ":test_tiktoken_tokenizer_model": "test_tiktoken_tokenizer.model",
+            ":test_tiktoken_tokenizer_model": "resources/test_tiktoken_tokenizer.model",
         },
     )
diff --git a/test/test_python_bindings.py b/test/test_python_bindings.py
new file mode 100644
index 0000000..1ba9ca5
--- /dev/null
+++ b/test/test_python_bindings.py
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# @lint-ignore-every LICENSELINT + +""" +Test script for PyTorch Tokenizers Python bindings +""" + +import os +import sys +import unittest + +try: + import pytorch_tokenizers +except ImportError as e: + print(f"Failed to import pytorch_tokenizers: {e}") + print("Make sure the package is installed with: pip install .") + sys.exit(1) + + +class TestPythonBindings(unittest.TestCase): + """Test cases for Python bindings""" + + def test_import_success(self): + """Test that all classes can be imported successfully""" + # Test that all expected classes are available + self.assertTrue(hasattr(pytorch_tokenizers, "Error")) + self.assertTrue(hasattr(pytorch_tokenizers, "TokenIndex")) + self.assertTrue(hasattr(pytorch_tokenizers, "Tokenizer")) + self.assertTrue(hasattr(pytorch_tokenizers, "CppHFTokenizer")) + self.assertTrue(hasattr(pytorch_tokenizers, "CppTiktoken")) + self.assertTrue(hasattr(pytorch_tokenizers, "CppLlama2cTokenizer")) + self.assertTrue(hasattr(pytorch_tokenizers, "CppSPTokenizer")) + + def test_error_enum(self): + """Test Error enum values""" + self.assertTrue(hasattr(pytorch_tokenizers.Error, "Ok")) + self.assertTrue(hasattr(pytorch_tokenizers.Error, "Internal")) + self.assertTrue(hasattr(pytorch_tokenizers.Error, "Uninitialized")) + self.assertTrue(hasattr(pytorch_tokenizers.Error, "OutOfRange")) + self.assertTrue(hasattr(pytorch_tokenizers.Error, "LoadFailure")) + self.assertTrue(hasattr(pytorch_tokenizers.Error, "EncodeFailure")) + self.assertTrue(hasattr(pytorch_tokenizers.Error, "Base64DecodeFailure")) + self.assertTrue(hasattr(pytorch_tokenizers.Error, "ParseFailure")) + self.assertTrue(hasattr(pytorch_tokenizers.Error, "DecodeFailure")) + self.assertTrue(hasattr(pytorch_tokenizers.Error, "RegexFailure")) + + def test_tokenizer_creation(self): + """Test that tokenizers can be created""" + # Test HFTokenizer creation + hf_tokenizer = pytorch_tokenizers.CppHFTokenizer() + self.assertIsInstance(hf_tokenizer, pytorch_tokenizers.CppHFTokenizer) + self.assertFalse(hf_tokenizer.is_loaded()) + + # Test Tiktoken creation + tiktoken_tokenizer = pytorch_tokenizers.CppTiktoken() + self.assertIsInstance(tiktoken_tokenizer, pytorch_tokenizers.CppTiktoken) + self.assertFalse(tiktoken_tokenizer.is_loaded()) + + # Test Llama2cTokenizer creation + llama2c_tokenizer = pytorch_tokenizers.CppLlama2cTokenizer() + self.assertIsInstance(llama2c_tokenizer, pytorch_tokenizers.CppLlama2cTokenizer) + self.assertFalse(llama2c_tokenizer.is_loaded()) + + # Test SPTokenizer creation + sp_tokenizer = pytorch_tokenizers.CppSPTokenizer() + self.assertIsInstance(sp_tokenizer, pytorch_tokenizers.CppSPTokenizer) + self.assertFalse(sp_tokenizer.is_loaded()) + + def test_tokenizer_methods(self): + """Test that tokenizer methods exist and behave correctly for unloaded tokenizers""" + hf_tokenizer = pytorch_tokenizers.CppHFTokenizer() + + # Test basic properties + self.assertEqual(hf_tokenizer.vocab_size(), 0) + self.assertEqual(hf_tokenizer.bos_tok(), 0) + self.assertEqual(hf_tokenizer.eos_tok(), 0) + self.assertFalse(hf_tokenizer.is_loaded()) + + # Test that encode fails with unloaded tokenizer + with self.assertRaises(RuntimeError): + hf_tokenizer.encode("Hello world", 1, 1) + + # Test that decode fails with unloaded tokenizer + with self.assertRaises(RuntimeError): + hf_tokenizer.decode(1) + + def test_version(self): + """Test that version is available""" + self.assertTrue(hasattr(pytorch_tokenizers, "__version__")) + self.assertEqual(pytorch_tokenizers.__version__, "0.1.0") + + def 
test_hf_tokenizer_encode_decode(self): + """Test HFTokenizer with test_hf_tokenizer.json to encode/decode 'Hello world!'""" + # Get the path to the test tokenizer file + tokenizer_path = os.path.join( + os.path.dirname(__file__), "resources/test_hf_tokenizer.json" + ) + print(tokenizer_path) + + # Create and load the tokenizer + hf_tokenizer = pytorch_tokenizers.CppHFTokenizer() + self.assertFalse(hf_tokenizer.is_loaded()) + + # Load the tokenizer from JSON file + hf_tokenizer.load(tokenizer_path) + self.assertTrue(hf_tokenizer.is_loaded()) + + # Test encoding "Hello world!" + text = "Hello world!" + encoded_tokens = hf_tokenizer.encode(text, 1, 0) # bos=1, eos=0 + self.assertIsInstance(encoded_tokens, list) + self.assertGreater(len(encoded_tokens), 0) + + # Test decoding the encoded tokens + for token_id in encoded_tokens: + decoded_text = hf_tokenizer.decode(token_id) + self.assertIsInstance(decoded_text, str) + + # Test that we can get vocab size + vocab_size = hf_tokenizer.vocab_size() + self.assertGreater(vocab_size, 0) + + # Test BOS and EOS tokens + bos_token = hf_tokenizer.bos_tok() + eos_token = hf_tokenizer.eos_tok() + self.assertIsInstance(bos_token, int) + self.assertIsInstance(eos_token, int) diff --git a/test/test_tiktoken.py b/test/test_tiktoken.py index d3e9489..045ef84 100644 --- a/test/test_tiktoken.py +++ b/test/test_tiktoken.py @@ -5,27 +5,28 @@ # LICENSE file in the root directory of this source tree. # @lint-ignore-every LICENSELINT +import os import unittest -import pkg_resources - from pytorch_tokenizers.tiktoken import TiktokenTokenizer class TestTiktokenTokenizer(unittest.TestCase): def test_default(self): - model_path = pkg_resources.resource_filename( - "pytorch.tokenizers.test", "test_tiktoken_tokenizer.model" + model_path = os.path.join( + os.path.dirname(__file__), "resources/test_tiktoken_tokenizer.model" ) + tiktoken = TiktokenTokenizer(model_path) s = "<|begin_of_text|> hellow world." self.assertEqual(s, tiktoken.decode(tiktoken.encode(s, bos=False, eos=False))) def test_custom_pattern_and_special_tokens(self): o220k_pattern = r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+""" - model_path = pkg_resources.resource_filename( - "pytorch.tokenizers.test", "test_tiktoken_tokenizer.model" + model_path = os.path.join( + os.path.dirname(__file__), "resources/test_tiktoken_tokenizer.model" ) + tiktoken = TiktokenTokenizer( model_path, pat_str=o220k_pattern,
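# ---------------------------------------------------------------------------
# Usage sketch for the bindings added in this patch. Illustrative only: the
# file paths and special-token strings below are assumptions, not fixtures
# shipped with the repo.
from pytorch_tokenizers import CppHFTokenizer, CppTiktoken

# New Tiktoken constructor overload bound above: a list of special tokens
# plus the indices of the BOS/EOS entries within that list.
special_tokens = ["<|begin_of_text|>", "<|end_of_text|>"]
tik = CppTiktoken(special_tokens, 0, 1)  # bos_token_index=0, eos_token_index=1
tik.load("path/to/tiktoken.model")  # hypothetical model path

# encode(input, bos, eos): bos/eos are the number of BOS/EOS markers to add.
ids = tik.encode("hello world", 1, 0)

# decode() is bound per token, so a full round trip concatenates the pieces.
text = "".join(tik.decode(t) for t in ids)

# Every tokenizer shares load/encode/decode/vocab_size/bos_tok/eos_tok via
# the common Tokenizer base bindings; HFTokenizer loads a tokenizer.json.
hf = CppHFTokenizer()
hf.load("path/to/tokenizer.json")  # hypothetical path
assert hf.is_loaded() and hf.vocab_size() > 0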