Skip to content

Commit c78af8f

Browse files
committed
add_cli_tokenizer
1 parent 0c6f193 commit c78af8f

File tree

3 files changed

+806
-0
lines changed

3 files changed

+806
-0
lines changed

fastdeploy/entrypoints/cli/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ def main():
2424
import fastdeploy.entrypoints.cli.benchmark.main
2525
import fastdeploy.entrypoints.cli.openai
2626
import fastdeploy.entrypoints.cli.serve
27+
import fastdeploy.entrypoints.cli.tokenizer
2728
from fastdeploy.utils import FlexibleArgumentParser
2829

2930
CMD_MODULES = [
31+
fastdeploy.entrypoints.cli.tokenizer,
3032
fastdeploy.entrypoints.cli.openai,
3133
fastdeploy.entrypoints.cli.benchmark.main,
3234
fastdeploy.entrypoints.cli.serve,
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License"
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
from __future__ import annotations
18+
19+
import argparse
20+
import json
21+
import typing
22+
from pathlib import Path
23+
24+
from fastdeploy.entrypoints.cli.types import CLISubcommand
25+
from fastdeploy.input.preprocess import InputPreprocessor
26+
27+
if typing.TYPE_CHECKING:
28+
from fastdeploy.utils import FlexibleArgumentParser
29+
30+
31+
class TokenizerSubcommand(CLISubcommand):
    """The `tokenizer` subcommand for the FastDeploy CLI.

    Offline tokenizer utilities: encode text to token ids, decode token
    ids back to text, and inspect or export a model tokenizer's
    vocabulary. Dispatches to :func:`main` in this module.
    """

    name = "tokenizer"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Entry point invoked by the CLI dispatcher with parsed args."""
        main(args)

    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register the `tokenizer` subparser and its arguments.

        Returns the parser so the dispatcher can attach the command.
        """
        # NOTE: the original help/description ("Start the FastDeploy
        # Tokenizer Server.") was copied from the `serve` subcommand;
        # this command is an offline utility and starts no server.
        tokenizer_parser = subparsers.add_parser(
            name=self.name,
            help="Run FastDeploy tokenizer utilities (encode/decode/inspect/export).",
            description="Run FastDeploy tokenizer utilities (encode/decode/inspect/export).",
            usage="fastdeploy tokenizer [--encode/-e TEXT] [--decode/-d TEXT]",
        )

        # Common arguments.
        tokenizer_parser.add_argument(
            "--model_name_or_path",
            "-m",
            type=str,
            default="baidu/ERNIE-4.5-0.3B-PT",
            help="Path to model or model identifier",
        )
        tokenizer_parser.add_argument("--vocab-size", action="store_true", help="Show vocabulary size")
        tokenizer_parser.add_argument("--info", action="store_true", help="Show tokenizer information")
        tokenizer_parser.add_argument("--vocab-export", type=str, metavar="FILE", help="Export vocabulary to file")
        tokenizer_parser.add_argument("--encode", "-e", default=None, help="Encode text to tokens")
        tokenizer_parser.add_argument("--decode", "-d", default=None, help="Decode tokens to text")

        return tokenizer_parser
63+
64+
65+
def cmd_init() -> list[CLISubcommand]:
    """Return the subcommand instances contributed by this module."""
    commands: list[CLISubcommand] = [TokenizerSubcommand()]
    return commands
67+
68+
69+
def get_vocab_size(tokenizer) -> int:
    """Return the vocabulary size of *tokenizer*.

    Prefers a ``vocab_size`` attribute, then a ``get_vocab_size()``
    method; otherwise falls back to the fixed Ernie4_5Tokenizer
    vocabulary size. Returns 0 when the lookup raises.
    """
    try:
        if not hasattr(tokenizer, "vocab_size"):
            if hasattr(tokenizer, "get_vocab_size"):
                return tokenizer.get_vocab_size()
            return 100295  # fixed vocabulary size of Ernie4_5Tokenizer
        return tokenizer.vocab_size
    except Exception:
        return 0
80+
81+
82+
def get_tokenizer_info(tokenizer) -> dict:
    """Collect metadata about *tokenizer* into a plain dict.

    Reports vocabulary size, model name/path (when recorded), the
    concrete tokenizer class name, the special tokens and ids that are
    present, and the model's maximum sequence length. On failure an
    ``error`` key is set instead of raising.
    """
    info = {}

    try:
        # Basic attributes.
        info["vocab_size"] = get_vocab_size(tokenizer)

        # Model identifier, when the tokenizer records one.
        if hasattr(tokenizer, "name_or_path"):
            info["model_name"] = tokenizer.name_or_path

        # Concrete tokenizer class.
        info["tokenizer_type"] = type(tokenizer).__name__

        # Special tokens — only truthy values are reported.
        token_attrs = (
            "bos_token",
            "eos_token",
            "unk_token",
            "sep_token",
            "pad_token",
            "cls_token",
            "mask_token",
        )
        info["special_tokens"] = {
            name: value for name in token_attrs if (value := getattr(tokenizer, name, None))
        }

        # Special token ids — 0 is a valid id, so only None is filtered.
        id_attrs = (
            "bos_token_id",
            "eos_token_id",
            "unk_token_id",
            "sep_token_id",
            "pad_token_id",
            "cls_token_id",
            "mask_token_id",
        )
        info["special_token_ids"] = {
            name: tid for name in id_attrs if (tid := getattr(tokenizer, name, None)) is not None
        }

        # Maximum model sequence length.
        if hasattr(tokenizer, "model_max_length"):
            info["model_max_length"] = tokenizer.model_max_length

    except Exception as e:
        info["error"] = f"Failed to get tokenizer info: {e}"

    return info
131+
132+
133+
def get_vocab_dict(tokenizer) -> dict:
    """Return the token-to-id vocabulary mapping of *tokenizer*.

    Probes, in order: a ``vocab`` attribute, a ``get_vocab()`` method,
    a nested ``tokenizer.vocab`` attribute, and an ``encoder`` mapping.
    Returns an empty dict when none are present or the lookup raises.
    """
    try:
        if hasattr(tokenizer, "vocab"):
            return tokenizer.vocab
        if hasattr(tokenizer, "get_vocab"):
            return tokenizer.get_vocab()
        inner = getattr(tokenizer, "tokenizer", None)
        if inner is not None and hasattr(inner, "vocab"):
            return inner.vocab
        if hasattr(tokenizer, "encoder"):
            return tokenizer.encoder
        return {}
    except Exception:
        return {}
148+
149+
150+
def export_vocabulary(tokenizer, file_path: str) -> None:
    """Export the tokenizer's vocabulary to *file_path*.

    A ``.json`` extension produces a JSON object mapping token to id;
    any other extension produces one tab-separated ``id``/``repr(token)``
    line per token, sorted by token id. Errors are reported on stdout
    rather than raised (best-effort CLI helper).
    """
    try:
        vocab = get_vocab_dict(tokenizer)
        if not vocab:
            print("Warning: Could not retrieve vocabulary from tokenizer")
            return

        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)

        # Choose the output format from the file extension.
        if path.suffix.lower() == ".json":
            with open(path, "w", encoding="utf-8") as f:
                json.dump(vocab, f, ensure_ascii=False, indent=2)
        else:
            # Default format: one token per line.
            with open(path, "w", encoding="utf-8") as f:
                for token, token_id in sorted(vocab.items(), key=lambda x: x[1]):
                    # Guard against tokens whose repr()/encoding fails.
                    # Fixed: was a bare `except:`, which also swallowed
                    # SystemExit/KeyboardInterrupt.
                    try:
                        f.write(f"{token_id}\t{repr(token)}\n")
                    except Exception:
                        f.write(f"{token_id}\t<unprintable>\n")

        print(f"Vocabulary exported to: {file_path}")
        print(f"Total tokens: {len(vocab)}")

    except Exception as e:
        print(f"Error exporting vocabulary: {e}")
180+
181+
182+
def main(args: argparse.Namespace) -> None:
    """Run the tokenizer CLI actions selected in *args*.

    Supported actions: encode text, decode token ids, show vocabulary
    size, dump tokenizer metadata, and export the vocabulary to a file.
    Several actions may be combined in a single invocation; each prints
    its result under a separator banner.
    """

    def print_separator(title: str = "") -> None:
        # Visual divider between the output of individual actions.
        if title:
            print(f"\n{'='*50}")
            print(f" {title}")
            print(f"{'='*50}")
        else:
            print(f"\n{'='*50}")

    # Require at least one action. `is not None` (rather than
    # truthiness) lets `--encode ""` count as a request.
    requested = [
        args.encode is not None,
        args.decode is not None,
        args.vocab_size,
        args.info,
        args.vocab_export,
    ]
    if not any(requested):
        # Fixed: the message previously named `--export-vocab`, a flag
        # that does not exist; the real flag is `--vocab-export`.
        print("请至少指定一个参数:--encode, --decode, --vocab-size, --info, --vocab-export")
        return

    # Build the tokenizer through the model's input preprocessor.
    preprocessor = InputPreprocessor(model_name_or_path=args.model_name_or_path)
    tokenizer = preprocessor.create_processor().tokenizer

    # Count actions actually performed for the summary line.
    operations_count = 0

    if args.encode is not None:
        print_separator("ENCODING")
        print(f"Input text: {args.encode}")
        encoded_text = tokenizer.encode(args.encode)
        print(f"Encoded tokens: {encoded_text}")
        operations_count += 1

    if args.decode is not None:
        print_separator("DECODING")
        print(f"Input tokens: {args.decode}")
        try:
            if isinstance(args.decode, str):
                if args.decode.startswith("[") and args.decode.endswith("]"):
                    # Fixed: was `eval(args.decode)` — executing
                    # arbitrary code from CLI input. json.loads parses
                    # a bracketed list of ints safely.
                    tokens = json.loads(args.decode)
                else:
                    tokens = [int(x.strip()) for x in args.decode.split(",")]
            else:
                tokens = args.decode

            decoded_text = tokenizer.decode(tokens)
            print(f"Decoded text: {decoded_text}")
        except Exception as e:
            print(f"Error decoding tokens: {e}")
        operations_count += 1

    if args.vocab_size:
        print_separator("VOCABULARY SIZE")
        print(f"Vocabulary size: {get_vocab_size(tokenizer)}")
        operations_count += 1

    if args.info:
        print_separator("TOKENIZER INFO")
        print(json.dumps(get_tokenizer_info(tokenizer), indent=2))
        operations_count += 1

    if args.vocab_export:
        print_separator("EXPORT VOCABULARY")
        export_vocabulary(tokenizer, args.vocab_export)
        operations_count += 1

    print_separator()
    print(f"Completed {operations_count} operation(s)")

0 commit comments

Comments
 (0)