
Commit 7ab50af

splitting registration and refactoring vocab.py module (#1352)

1 parent aa75fe0

File tree: 12 files changed, +493 −455 lines changed

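In summary: the single C++ registration file is split in two — the pybind11 bindings (PYBIND11_MODULE) stay where they are, while the TorchScript registrations (TORCH_LIBRARY_FRAGMENT) move into a new file of their own. The Python-tokenizer vocab factory is moved out of vocab.cpp (the pybind file now includes vocab_factory.h), and the CompareTokens comparator plus the _infer_lines declaration are hoisted into vocab.h so other translation units can reuse them; vocab.h also drops its pybind11 dependency.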

.circleci/unittest/linux/scripts/environment.yml

Lines changed: 1 addition & 0 deletions

@@ -18,5 +18,6 @@ dependencies:
 - sphinx
 - sphinx-rtd-theme
 - tqdm
+- expecttest
 - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0
 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0

.circleci/unittest/windows/scripts/environment.yml

Lines changed: 1 addition & 0 deletions

@@ -21,5 +21,6 @@ dependencies:
 - tqdm
 - certifi
 - future
+- expecttest
 - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0
 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0
Lines changed: 4 additions & 120 deletions

@@ -9,6 +9,8 @@
 #include <torch/script.h>
 #include <vectors.h> // @manual
 #include <vocab.h> // @manual
+#include <vocab_factory.h>
+
 namespace torchtext {
 
 namespace py = pybind11;
@@ -155,126 +157,8 @@ PYBIND11_MODULE(_torchtext, m) {
         &_load_token_and_vectors_from_file);
   m.def("_load_vocab_from_file", &_load_vocab_from_file);
   m.def("_build_vocab_from_text_file", &build_vocab_from_text_file);
-  m.def("_build_vocab_from_text_file_using_python_tokenizer", &_build_vocab_from_text_file_using_python_tokenizer);
-}
-
-TORCH_LIBRARY_FRAGMENT(torchtext, m) {
-  m.class_<Regex>("Regex")
-      .def(torch::init<std::string>())
-      .def("Sub", &Regex::Sub)
-      .def_pickle(
-          // __getstate__
-          [](const c10::intrusive_ptr<Regex> &self) -> std::string {
-            return _serialize_regex(self);
-          },
-          // __setstate__
-          [](std::string state) -> c10::intrusive_ptr<Regex> {
-            return _deserialize_regex(std::move(state));
-          });
-
-  m.class_<RegexTokenizer>("RegexTokenizer")
-      .def(torch::init<std::vector<std::string>, std::vector<std::string>,
-                       bool>())
-      .def("forward", &RegexTokenizer::forward)
-      .def_pickle(
-          // __getstate__
-          [](const c10::intrusive_ptr<RegexTokenizer> &self)
-              -> RegexTokenizerStates {
-            return _serialize_regex_tokenizer(self);
-          },
-          // __setstate__
-          [](RegexTokenizerStates states)
-              -> c10::intrusive_ptr<RegexTokenizer> {
-            return _deserialize_regex_tokenizer(std::move(states));
-          });
-
-  m.class_<SentencePiece>("SentencePiece")
-      .def(torch::init<std::string>())
-      .def("Encode", &SentencePiece::Encode)
-      .def("EncodeAsIds", &SentencePiece::EncodeAsIds)
-      .def("DecodeIds", &SentencePiece::DecodeIds)
-      .def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
-      .def("DecodePieces", &SentencePiece::DecodePieces)
-      .def("GetPieceSize", &SentencePiece::GetPieceSize)
-      .def("unk_id", &SentencePiece::unk_id)
-      .def("PieceToId", &SentencePiece::PieceToId)
-      .def("IdToPiece", &SentencePiece::IdToPiece)
-      .def_pickle(
-          // The underlying content of SentencePiece contains byte string,
-          // and returing it as std::string cause UTF8 decoding error.
-          // Since TorchScript does not support byte string, we use byte Tensor
-          // to pass around the data.
-          // __getstate__
-          [](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
-            auto *data =
-                static_cast<void *>(const_cast<char *>(self->content_.data()));
-            auto numel = static_cast<int64_t>(self->content_.size());
-            return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
-          },
-          // __setstate__
-          [](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
-            auto *data = static_cast<char *>(state.data_ptr());
-            auto numel = state.size(0);
-            return c10::make_intrusive<SentencePiece>(std::string(data, numel));
-          });
-
-  m.class_<Vectors>("Vectors")
-      .def(torch::init<std::vector<std::string>, std::vector<std::int64_t>,
-                       torch::Tensor, torch::Tensor>())
-      .def("__getitem__", &Vectors::__getitem__)
-      .def("lookup_vectors", &Vectors::lookup_vectors)
-      .def("__setitem__", &Vectors::__setitem__)
-      .def("__len__", &Vectors::__len__)
-      .def_pickle(
-          // __getstate__
-          [](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
-            return _serialize_vectors(self);
-          },
-          // __setstate__
-          [](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
-            return _deserialize_vectors(states);
-          });
-
-  m.class_<Vocab>("Vocab")
-      .def(torch::init<StringList, c10::optional<int64_t>>())
-      .def("__contains__",
-           [](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-               -> bool { return self->__contains__(c10::string_view{item}); })
-      .def("__getitem__",
-           [](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-               -> int64_t { return self->__getitem__(c10::string_view{item}); })
-      .def("insert_token", &Vocab::insert_token)
-      .def("__len__", &Vocab::__len__)
-      .def("set_default_index", &Vocab::set_default_index)
-      .def("get_default_index", &Vocab::get_default_index)
-      .def("append_token", &Vocab::append_token)
-      .def("lookup_token", &Vocab::lookup_token)
-      .def("lookup_tokens", &Vocab::lookup_tokens)
-      .def("lookup_indices",
-           [](const c10::intrusive_ptr<Vocab> &self,
-              const std::vector<std::string> &items) {
-             std::vector<int64_t> indices(items.size());
-             int64_t counter = 0;
-             for (const auto &item : items) {
-               indices[counter++] = self->__getitem__(c10::string_view{item});
-             }
-             return indices;
-           })
-      .def("get_stoi", &Vocab::get_stoi)
-      .def("get_itos", &Vocab::get_itos)
-      .def_pickle(
-          // __getstate__
-          [](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
-            return _serialize_vocab(self);
-          },
-          // __setstate__
-          [](VocabStates states) -> c10::intrusive_ptr<Vocab> {
-            return _deserialize_vocab(states);
-          });
-
-  m.def("torchtext::generate_sp_model", &generate_sp_model);
-  m.def("torchtext::load_sp_model", &load_sp_model);
-  m.def("torchtext::load_sp_model_string", &load_sp_model_string);
+  m.def("_build_vocab_from_text_file_using_python_tokenizer",
+        &_build_vocab_from_text_file_using_python_tokenizer);
 }
 
 } // namespace torchtext
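The split mirrors how the two registration mechanisms surface in Python: PYBIND11_MODULE bindings arrive through the compiled _torchtext extension module, while TORCH_LIBRARY_FRAGMENT registrations are reached through torch.ops and torch.classes. A minimal sketch of the two call paths (the exact import path and argument lists are assumptions for illustration, not the library's documented API):

    import torch
    from torchtext import _torchtext  # extension built by PYBIND11_MODULE(_torchtext, m)

    # pybind11 path: plain functions on the extension module
    # (hypothetical arguments)
    vocab = _torchtext._load_vocab_from_file("vocab.txt", 1, 1)

    # TorchScript path: namespaced ops and custom classes
    sp = torch.ops.torchtext.load_sp_model("spm.model")
    regex = torch.classes.torchtext.Regex(r"\s+")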
Lines changed: 129 additions & 0 deletions

@@ -0,0 +1,129 @@
+#include <iostream>
+#include <regex.h>
+#include <regex_tokenizer.h> // @manual
+#include <sentencepiece.h> // @manual
+#include <torch/script.h>
+#include <vectors.h> // @manual
+#include <vocab.h> // @manual
+namespace torchtext {
+
+TORCH_LIBRARY_FRAGMENT(torchtext, m) {
+  m.class_<Regex>("Regex")
+      .def(torch::init<std::string>())
+      .def("Sub", &Regex::Sub)
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<Regex> &self) -> std::string {
+            return _serialize_regex(self);
+          },
+          // __setstate__
+          [](std::string state) -> c10::intrusive_ptr<Regex> {
+            return _deserialize_regex(std::move(state));
+          });
+
+  m.class_<RegexTokenizer>("RegexTokenizer")
+      .def(torch::init<std::vector<std::string>, std::vector<std::string>,
+                       bool>())
+      .def("forward", &RegexTokenizer::forward)
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<RegexTokenizer> &self)
+              -> RegexTokenizerStates {
+            return _serialize_regex_tokenizer(self);
+          },
+          // __setstate__
+          [](RegexTokenizerStates states)
+              -> c10::intrusive_ptr<RegexTokenizer> {
+            return _deserialize_regex_tokenizer(std::move(states));
+          });
+
+  m.class_<SentencePiece>("SentencePiece")
+      .def(torch::init<std::string>())
+      .def("Encode", &SentencePiece::Encode)
+      .def("EncodeAsIds", &SentencePiece::EncodeAsIds)
+      .def("DecodeIds", &SentencePiece::DecodeIds)
+      .def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
+      .def("DecodePieces", &SentencePiece::DecodePieces)
+      .def("GetPieceSize", &SentencePiece::GetPieceSize)
+      .def("unk_id", &SentencePiece::unk_id)
+      .def("PieceToId", &SentencePiece::PieceToId)
+      .def("IdToPiece", &SentencePiece::IdToPiece)
+      .def_pickle(
+          // The underlying content of SentencePiece contains byte string,
+          // and returing it as std::string cause UTF8 decoding error.
+          // Since TorchScript does not support byte string, we use byte Tensor
+          // to pass around the data.
+          // __getstate__
+          [](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
+            auto *data =
+                static_cast<void *>(const_cast<char *>(self->content_.data()));
+            auto numel = static_cast<int64_t>(self->content_.size());
+            return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
+          },
+          // __setstate__
+          [](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
+            auto *data = static_cast<char *>(state.data_ptr());
+            auto numel = state.size(0);
+            return c10::make_intrusive<SentencePiece>(std::string(data, numel));
+          });
+
+  m.class_<Vectors>("Vectors")
+      .def(torch::init<std::vector<std::string>, std::vector<std::int64_t>,
+                       torch::Tensor, torch::Tensor>())
+      .def("__getitem__", &Vectors::__getitem__)
+      .def("lookup_vectors", &Vectors::lookup_vectors)
+      .def("__setitem__", &Vectors::__setitem__)
+      .def("__len__", &Vectors::__len__)
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
+            return _serialize_vectors(self);
+          },
+          // __setstate__
+          [](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
+            return _deserialize_vectors(states);
+          });
+
+  m.class_<Vocab>("Vocab")
+      .def(torch::init<StringList, c10::optional<int64_t>>())
+      .def("__contains__",
+           [](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
+               -> bool { return self->__contains__(c10::string_view{item}); })
+      .def("__getitem__",
+           [](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
+               -> int64_t { return self->__getitem__(c10::string_view{item}); })
+      .def("insert_token", &Vocab::insert_token)
+      .def("__len__", &Vocab::__len__)
+      .def("set_default_index", &Vocab::set_default_index)
+      .def("get_default_index", &Vocab::get_default_index)
+      .def("append_token", &Vocab::append_token)
+      .def("lookup_token", &Vocab::lookup_token)
+      .def("lookup_tokens", &Vocab::lookup_tokens)
+      .def("lookup_indices",
+           [](const c10::intrusive_ptr<Vocab> &self,
+              const std::vector<std::string> &items) {
+             std::vector<int64_t> indices(items.size());
+             int64_t counter = 0;
+             for (const auto &item : items) {
+               indices[counter++] = self->__getitem__(c10::string_view{item});
+             }
+             return indices;
+           })
+      .def("get_stoi", &Vocab::get_stoi)
+      .def("get_itos", &Vocab::get_itos)
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
+            return _serialize_vocab(self);
+          },
+          // __setstate__
+          [](VocabStates states) -> c10::intrusive_ptr<Vocab> {
+            return _deserialize_vocab(states);
+          });
+
+  m.def("torchtext::generate_sp_model", &generate_sp_model);
+  m.def("torchtext::load_sp_model", &load_sp_model);
+  m.def("torchtext::load_sp_model_string", &load_sp_model_string);
+}
+
+} // namespace torchtext
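Every class in this new file defines __getstate__/__setstate__ via def_pickle, which is what lets scripted modules holding them survive torch.jit.save/torch.jit.load. A minimal sketch (treat the Regex constructor argument and the Sub signature as assumptions inferred from the bindings above):

    import torch

    class Normalize(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # custom class registered by the TORCH_LIBRARY_FRAGMENT above
            self.regex = torch.classes.torchtext.Regex(r"\s+")

        def forward(self, line: str) -> str:
            # collapse runs of whitespace into a single space
            return self.regex.Sub(line, " ")

    scripted = torch.jit.script(Normalize())
    torch.jit.save(scripted, "normalize.pt")  # round-trips thanks to def_pickle
    restored = torch.jit.load("normalize.pt")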

torchtext/csrc/vocab.cpp

Lines changed: 0 additions & 59 deletions

@@ -191,17 +191,6 @@ void parse_raw_text_file_chunk(const std::string &file_path, size_t offset,
   }
 }
 
-// sorting using a custom object
-struct CompareTokens {
-  bool operator()(const std::pair<std::string, int64_t> &a,
-                  const std::pair<std::string, int64_t> &b) {
-    if (a.second == b.second) {
-      return a.first < b.first;
-    }
-    return a.second > b.second;
-  }
-};
-
 StringList
 _concat_tokens(std::vector<std::shared_ptr<IndexDict>> chunk_counters,
                const int64_t min_freq, const int64_t num_lines,
@@ -345,54 +334,6 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
   return Vocab(std::move(tokens));
 }
 
-Vocab _build_vocab_from_text_file_using_python_tokenizer(
-    const std::string &file_path, const int64_t min_freq,
-    py::object tokenizer) {
-  // find number of lines
-  int64_t num_lines = _infer_lines(file_path);
-  // Read text from file and add tokens
-  std::ifstream fin(file_path, std::ios::in);
-  TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);
-
-  IndexDict counter;
-  std::string line;
-  for (int64_t i = 0; i < num_lines; i++) {
-    std::getline(fin, line);
-    std::vector<std::string> token_list =
-        tokenizer(line).cast<std::vector<std::string>>();
-
-    for (size_t i = 0; i < token_list.size(); i++) {
-      std::string token = token_list[i];
-
-      if (counter.find(token) == counter.end()) {
-        counter[token] = 1;
-      } else {
-        counter[token] += 1;
-      }
-    }
-  }
-
-  // create tokens-frequency pairs
-  std::vector<std::pair<std::string, int64_t>> token_freq_pairs;
-  for (const auto &item : counter) {
-    if (item.second >= min_freq) {
-      token_freq_pairs.push_back(item);
-    }
-  }
-
-  // sort tokens by frequency
-  CompareTokens compare_tokens;
-  std::sort(token_freq_pairs.begin(), token_freq_pairs.end(), compare_tokens);
-
-  // Create final list of tokens
-  StringList tokens;
-  for (const auto &token_freq_pair : token_freq_pairs) {
-    tokens.push_back(token_freq_pair.first);
-  }
-
-  return Vocab(std::move(tokens));
-}
-
 VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
   std::vector<int64_t> integers;
   StringList strings = self->itos_;
torchtext/csrc/vocab.h

Lines changed: 14 additions & 6 deletions

@@ -1,10 +1,8 @@
+#pragma once
 #include <algorithm>
 #include <c10/util/string_view.h>
-#include <pybind11/pybind11.h>
 #include <torch/script.h>
 
-namespace py = pybind11;
-
 namespace torchtext {
 
 typedef std::vector<std::string> StringList;
@@ -14,6 +12,19 @@ typedef std::tuple<std::string, std::vector<int64_t>, std::vector<std::string>,
                    std::vector<torch::Tensor>>
     VocabStates;
 
+// sorting using a custom object
+struct CompareTokens {
+  bool operator()(const std::pair<std::string, int64_t> &a,
+                  const std::pair<std::string, int64_t> &b) {
+    if (a.second == b.second) {
+      return a.first < b.first;
+    }
+    return a.second > b.second;
+  }
+};
+
+int64_t _infer_lines(const std::string &file_path);
+
 struct Vocab : torch::CustomClassHolder {
   static const int32_t MAX_VOCAB_SIZE = 30000000;
   int64_t unk_index_;
@@ -79,7 +90,4 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
                                   const int64_t min_freq,
                                   const int64_t num_cpus,
                                   torch::jit::script::Module tokenizer);
-Vocab _build_vocab_from_text_file_using_python_tokenizer(
-    const std::string &file_path, const int64_t min_freq, py::object tokenizer);
-
 } // namespace torchtext
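Moving CompareTokens and the _infer_lines declaration into the header makes them reusable from other translation units (the pybind file now includes vocab_factory.h, which presumably hosts the relocated factory), and dropping pybind11 from vocab.h keeps the TorchScript-only code free of a Python dependency. Note the comparator's total order: frequency descending, ties broken alphabetically — so ("c", 5) sorts before ("a", 3), which sorts before ("b", 3).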
