Skip to content

Commit c173f9f

Browse files
authored
Merge pull request #1 from pytorch-labs/add_base64
Add base64.h
2 parents 52eb48b + cec7b69 commit c173f9f

File tree

5 files changed

+321
-7
lines changed

5 files changed

+321
-7
lines changed

.github/workflows/pull.yml

+1-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ jobs:
1818
strategy:
1919
fail-fast: false
2020
with:
21-
runner: linux.4xlarge
22-
docker-image: executorch-ubuntu-22.04-clang12
21+
runner: linux.2xlarge
2322
submodules: 'true'
2423
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
2524
timeout: 90

CMakeLists.txt

+12-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ~~~
1515
# It should also be cmake-lint clean.
1616
#
17-
cmake_minimum_required(VERSION 3.24)
17+
cmake_minimum_required(VERSION 3.18)
1818
set(CMAKE_CXX_STANDARD 17)
1919

2020
project(Tokenizers)
@@ -38,12 +38,19 @@ target_link_libraries(tokenizers PUBLIC sentencepiece-static)
3838

3939
# Build test
4040
if(TOKENIZERS_BUILD_TEST)
41-
find_package(GTest REQUIRED)
41+
include(FetchContent)
42+
FetchContent_Declare(
43+
googletest
44+
# Specify the commit you depend on and update it regularly.
45+
URL https://github.com/google/googletest/archive/5376968f6948923e2411081fd9372e71a59d8e77.zip
46+
)
47+
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
48+
FetchContent_MakeAvailable(googletest)
49+
4250
set(ENV{RESOURCES_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/test/resources)
4351
add_executable(sentencepiece_test test/test_sentencepiece.cpp)
4452
target_include_directories(
4553
sentencepiece_test PUBLIC third-party/sentencepiece/src
46-
third-party/sentencepiece include)
47-
target_link_libraries(sentencepiece_test PUBLIC tokenizers GTest::GTest
48-
GTest::Main)
54+
third-party/sentencepiece include GTEST_INCLUDE_PATH)
55+
target_link_libraries(sentencepiece_test PUBLIC tokenizers gtest_main)
4956
endif()

include/base64.h

+195
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
// @lint-ignore-every LICENSELINT
9+
/**************************************************************************
10+
Copyright (c) 2023 sewenew
11+
12+
Licensed under the Apache License, Version 2.0 (the "License");
13+
you may not use this file except in compliance with the License.
14+
You may obtain a copy of the License at
15+
16+
http://www.apache.org/licenses/LICENSE-2.0
17+
18+
Unless required by applicable law or agreed to in writing, software
19+
distributed under the License is distributed on an "AS IS" BASIS,
20+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21+
See the License for the specific language governing permissions and
22+
limitations under the License.
23+
*************************************************************************/
24+
25+
#pragma once

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>
#include <string_view>

#include "result.h"
33+
34+
namespace base64 {
35+
36+
using tokenizers::Error;
37+
using tokenizers::Result;
38+
39+
Result<std::string> decode(const std::string_view &input);
40+
41+
namespace detail {
42+
43+
// Maps an ASCII byte (table index) to its 6-bit base64 value:
//   'A'-'Z' (65-90)  -> 0-25
//   'a'-'z' (97-122) -> 26-51
//   '0'-'9' (48-57)  -> 52-61
//   '+'     (43)     -> 62
//   '/'     (47)     -> 63
// Every other byte maps to 255, the sentinel checked by validate() below.
// 256 entries, 15 per row.
constexpr uint32_t DECODE_TABLE[] = {
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62,  255,
    255, 255, 63,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  255, 255,
    255, 255, 255, 255, 255, 0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
    10,  11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,
    25,  255, 255, 255, 255, 255, 255, 26,  27,  28,  29,  30,  31,  32,  33,
    34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,
    49,  50,  51,  255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255};
62+
63+
/**
 * Check a value looked up in DECODE_TABLE.
 *
 * @param v The table entry for one input character.
 * @return Error::Ok if the character was a valid base64 character,
 *         Error::Base64DecodeFailure if it mapped to the 255 sentinel.
 */
inline Error validate(uint32_t v) {
  if (v == 255) {
    // Fix: terminate the diagnostic with a newline so consecutive
    // messages on stderr do not run together.
    fprintf(stderr, "invalid char\n");
    return Error::Base64DecodeFailure;
  }
  return Error::Ok;
}
70+
71+
/**
 * Decode one full group of 4 base64 characters into 3 output bytes.
 *
 * @param input Exactly 4 base64 characters (no padding allowed here).
 * @param output Decoded bytes are appended to this string.
 * @return Error::Ok on success, Error::Base64DecodeFailure on bad
 *         length or an invalid character.
 */
inline Error decode(const std::string_view &input, std::string &output) {
  if (input.size() != 4) {
    fprintf(stderr, "input length must be 4, got %zu\n", input.size());
    return Error::Base64DecodeFailure;
  }

  // Accumulate 4 x 6 bits into a 24-bit value. Replaces four
  // hand-unrolled copies of the same stanza with a loop.
  uint32_t val = 0;
  for (const char ch : input) {
    const uint32_t v = DECODE_TABLE[static_cast<uint8_t>(ch)];
    TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
    val = (val << 6) | v;
  }

  // Emit the 24 bits as 3 bytes, most significant first.
  output.push_back(static_cast<char>((val >> 16) & 0xFF));
  output.push_back(static_cast<char>((val >> 8) & 0xFF));
  output.push_back(static_cast<char>(val & 0xFF));
  return Error::Ok;
}
104+
105+
/**
 * Decode the final base64 group when it ends with one '=' padding:
 * 3 data characters yield 2 output bytes.
 *
 * @param input Exactly 3 base64 characters (the '=' already stripped).
 * @param output Decoded bytes are appended to this string.
 * @return Error::Ok on success, Error::Base64DecodeFailure on bad
 *         length or an invalid character.
 */
inline Error decode_1_padding(const std::string_view &input,
                              std::string &output) {
  if (input.size() != 3) {
    fprintf(stderr, "input length must be 3, got %zu\n", input.size());
    return Error::Base64DecodeFailure;
  }

  // Accumulate 3 x 6 = 18 bits. Loop replaces three unrolled stanzas.
  uint32_t val = 0;
  for (const char ch : input) {
    const uint32_t v = DECODE_TABLE[static_cast<uint8_t>(ch)];
    TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
    val = (val << 6) | v;
  }

  // Top 16 of the 18 bits are payload; the low 2 bits are padding.
  output.push_back(static_cast<char>((val >> 10) & 0xFF));
  output.push_back(static_cast<char>((val >> 2) & 0xFF));
  return Error::Ok;
}
133+
134+
/**
 * Decode the final base64 group when it ends with two '=' paddings:
 * 2 data characters yield 1 output byte.
 *
 * @param input Exactly 2 base64 characters (the "==" already stripped).
 * @param output The decoded byte is appended to this string.
 * @return Error::Ok on success, Error::Base64DecodeFailure on bad
 *         length or an invalid character.
 */
inline Error decode_2_padding(const std::string_view &input,
                              std::string &output) {
  TK_CHECK_OR_RETURN_ERROR(input.size() == 2, Base64DecodeFailure);

  // Accumulate 2 x 6 = 12 bits. Loop replaces two unrolled stanzas.
  uint32_t val = 0;
  for (const char ch : input) {
    const uint32_t v = DECODE_TABLE[static_cast<uint8_t>(ch)];
    TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
    val = (val << 6) | v;
  }

  // Top 8 of the 12 bits are payload; the low 4 bits are padding.
  output.push_back(static_cast<char>((val >> 4) & 0xFF));
  return Error::Ok;
}
153+
154+
} // namespace detail
155+
156+
/**
 * Decode a base64-encoded string.
 *
 * @param input Base64 text: non-empty, length a multiple of 4, with at
 *        most two trailing '=' padding characters.
 * @return The decoded bytes on success, Error::Base64DecodeFailure on
 *         empty/odd-sized input or an invalid character.
 */
inline tokenizers::Result<std::string> decode(const std::string_view &input) {
  if (input.empty()) {
    fprintf(stderr, "empty input\n");
    return Error::Base64DecodeFailure;
  }

  // (size & 3) is a cheaper equivalent of (size % 4).
  if ((input.size() & 3) != 0 || input.size() < 4) {
    // Fixed message: a length of exactly 4 is valid, so "at least 4",
    // not "larger than 4".
    fprintf(stderr,
            "input length must be at least 4 and a multiple of 4, got %zu\n",
            input.size());
    return Error::Base64DecodeFailure;
  }

  std::string output;
  // Upper bound on the output size; padding may make it smaller.
  output.reserve(input.size() / 4 * 3);

  // size_t instead of the original `auto idx = 0U` (unsigned int):
  // avoids truncation against input.size() on inputs > 4 GiB.
  size_t idx = 0;
  for (; idx < input.size() - 4; idx += 4) {
    TK_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
  }

  // Last 4 characters: dispatch on the number of trailing paddings.
  if (input[idx + 3] == '=') {
    if (input[idx + 2] == '=') {
      // Two paddings: 2 data chars -> 1 byte.
      TK_CHECK_OK_OR_RETURN_ERROR(
          detail::decode_2_padding(input.substr(idx, 2), output));
    } else {
      // One padding: 3 data chars -> 2 bytes.
      TK_CHECK_OK_OR_RETURN_ERROR(
          detail::decode_1_padding(input.substr(idx, 3), output));
    }
  } else {
    // No padding: a full 4-char group.
    TK_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
  }

  return output;
}
195+
} // namespace base64

include/error.h

+32
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,38 @@ enum class Error : error_code_t {
4545

4646
/// Encode failure.
4747
EncodeFailure = 0x05,
48+
49+
/// Base64 decode failure.
50+
Base64DecodeFailure = 0x06,
4851
};
4952

5053
} // namespace tokenizers
54+
55+
/**
 * If cond__ is false, return the specified Error
 * from the current function, which must be of return type
 * tokenizers::Error (or a Result convertible from it).
 * Wrapped in do/while(0) so the macro behaves as a single statement
 * and is safe inside unbraced if/else.
 * TODO: Add logging support
 * @param[in] cond__ The condition to be checked, asserted as true.
 * @param[in] error__ Error enum value to return without the `Error::` prefix,
 * like `InvalidArgument`.
 */
#define TK_CHECK_OR_RETURN_ERROR(cond__, error__)                              \
  do {                                                                         \
    if (!(cond__)) {                                                           \
      return ::tokenizers::Error::error__;                                     \
    }                                                                          \
  } while (0)
70+
71+
/**
 * If error__ is not Error::Ok, return it from the current function.
 * The expression is evaluated exactly once (the original macro expanded
 * error__ twice, re-invoking any function call it contained), and the
 * body is wrapped in do/while(0) so it is safe inside unbraced if/else.
 * TODO: Add logging support
 * @param[in] error__ Expression yielding a tokenizers::Error value.
 */
#define TK_CHECK_OK_OR_RETURN_ERROR(error__)                                   \
  do {                                                                         \
    const auto tk_check_err__ = (error__);                                     \
    if (tk_check_err__ != ::tokenizers::Error::Ok) {                           \
      return tk_check_err__;                                                   \
    }                                                                          \
  } while (0)

include/tiktoken.h

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Tiktoken header
10+
// Used by OpenAI, adapted from https://github.com/sewenew/tokenizer
11+
#include "re2/re2.h"
12+
#include "tokenizer.h"
13+
#include <cstdint>
14+
15+
#pragma once
16+
17+
using Encoder = std::unordered_map<std::string, uint64_t>;
18+
using Decoder = std::unordered_map<uint64_t, std::string>;
19+
using Re2UPtr = std::unique_ptr<re2::RE2>;
20+
21+
namespace tokenizers {
22+
23+
class Tiktoken : public Tokenizer {
24+
public:
25+
explicit Tiktoken();
26+
~Tiktoken() override;
27+
28+
Error load(const std::string &tokenizer_path) override;
29+
30+
Result<std::vector<uint64_t>> encode(const std::string &input, int8_t bos,
31+
int8_t eos) const override;
32+
33+
Result<std::string> decode(uint64_t prev_token,
34+
uint64_t token) const override;
35+
36+
private:
37+
static inline const Encoder _get_special_tokens(ssize_t num_base_tokens) {
38+
Encoder special_tokens;
39+
special_tokens.emplace("<|begin_of_text|>", num_base_tokens++);
40+
special_tokens.emplace("<|end_of_text|>", num_base_tokens++);
41+
special_tokens.emplace("<|reserved_special_token_0|>", num_base_tokens++);
42+
special_tokens.emplace("<|reserved_special_token_1|>", num_base_tokens++);
43+
special_tokens.emplace("<|reserved_special_token_2|>", num_base_tokens++);
44+
special_tokens.emplace("<|reserved_special_token_3|>", num_base_tokens++);
45+
special_tokens.emplace("<|start_header_id|>", num_base_tokens++);
46+
special_tokens.emplace("<|end_header_id|>", num_base_tokens++);
47+
special_tokens.emplace("<|reserved_special_token_4|>", num_base_tokens++);
48+
special_tokens.emplace("<|eot_id|>", num_base_tokens++);
49+
for (auto i = 5; i < 251; ++i) {
50+
special_tokens.emplace("<|reserved_special_token_" + std::to_string(i) +
51+
"|>",
52+
num_base_tokens++);
53+
}
54+
return special_tokens;
55+
}
56+
57+
template <typename T>
58+
std::pair<std::optional<std::string>, re2::StringPiece>
59+
_split_with_allowed_special_token(re2::StringPiece &input,
60+
const T &allowed_special);
61+
62+
void _encode(re2::StringPiece &input, std::vector<uint64_t> &ret,
63+
uint64_t &last_piece_token_len);
64+
65+
template <typename T>
66+
std::pair<std::vector<uint64_t>, uint64_t>
67+
_encode_with_special_token(const std::string &text, const T &allowed_special);
68+
69+
// Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
70+
const std::string _pattern =
71+
R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
72+
Encoder _encoder;
73+
Encoder _special_token_encoder;
74+
Decoder _decoder;
75+
Decoder _special_token_decoder;
76+
77+
Re2UPtr _regex;
78+
Re2UPtr _special_token_regex;
79+
};
80+
81+
} // namespace tokenizers

0 commit comments

Comments
 (0)