
Commit 26ea0c3

Example doc update (#437)
* Add a new PyTorch example to show embedding. Update the pybind documentation.
* Change the embedding function to take a word dictionary.
1 parent f04cc55 commit 26ea0c3
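
At its core, the new demo boils down to one extra logging step: build an embedding component from a LogWriter and hand it the embedding matrix plus a word-to-index dictionary. A minimal sketch of that call, using the same LogWriter arguments as the demo below and a tiny made-up two-word vocabulary purely for illustration:

from visualdl import LogWriter

# Create a log writer and an embedding component under the "train" mode.
logw = LogWriter("./embedding_log", sync_cycle=10000)
with logw.mode('train') as logger:
    embedding = logger.embedding()

# One embedding vector per word; word_dict maps each word to its row index.
embeddings = [[-1.5, -0.7, -0.6, -1.6],
              [-0.8, -0.2, -0.2, 1.7]]
word_dict = {"Apple": 0, "Orange": 1}
embedding.add_embeddings_with_word_dict(embeddings, word_dict)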

File tree: 4 files changed, +141 −24 lines

demo/pytorch/pytorch_word2vec.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# http://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html?highlight=embedding
# The following tutorial is from the PyTorch site.
# =======================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Import VisualDL
from visualdl import LogWriter

torch.manual_seed(1)
CONTEXT_SIZE = 2
EMBEDDING_DIM = 10
# We will use Shakespeare Sonnet 2
test_sentence = """When forty winters shall besiege thy brow,
And dig deep trenches in thy beauty's field,
Thy youth's proud livery so gazed on now,
Will be a totter'd weed of small worth held:
Then being asked, where all thy beauty lies,
Where all the treasure of thy lusty days;
To say, within thine own deep sunken eyes,
Were an all-eating shame, and thriftless praise.
How much more praise deserv'd thy beauty's use,
If thou couldst answer 'This fair child of mine
Shall sum my count, and make my old excuse,'
Proving his beauty by succession thine!
This were to be new made when thou art old,
And see thy blood warm when thou feel'st it cold.""".split()
# We should tokenize the input, but we will ignore that for now.
# Build a list of tuples. Each tuple is ([word_i-2, word_i-1], target word).
trigrams = [([test_sentence[i], test_sentence[i + 1]], test_sentence[i + 2])
            for i in range(len(test_sentence) - 2)]
# Print the first 3, just so you can see what they look like.
print(trigrams[:3])

vocab = set(test_sentence)
word_to_ix = {word: i for i, word in enumerate(vocab)}


class NGramLanguageModeler(nn.Module):
    def __init__(self, vocab_size, embedding_dim, context_size):
        super(NGramLanguageModeler, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(context_size * embedding_dim, 128)
        self.linear2 = nn.Linear(128, vocab_size)

    def forward(self, inputs):
        embeds = self.embeddings(inputs).view((1, -1))
        out = F.relu(self.linear1(embeds))
        out = self.linear2(out)
        log_probs = F.log_softmax(out, dim=1)
        return log_probs


losses = []
loss_function = nn.NLLLoss()
model = NGramLanguageModeler(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)
optimizer = optim.SGD(model.parameters(), lr=0.001)

for epoch in range(10):
    total_loss = torch.Tensor([0])
    for context, target in trigrams:

        # Step 1. Prepare the inputs to be passed to the model (i.e., turn the
        # words into integer indices and wrap them in tensors).
        context_idxs = torch.tensor(
            [word_to_ix[w] for w in context], dtype=torch.long)

        # Step 2. Recall that torch *accumulates* gradients. Before passing in
        # a new instance, you need to zero out the gradients from the old
        # instance.
        model.zero_grad()

        # Step 3. Run the forward pass, getting log probabilities over next
        # words.
        log_probs = model(context_idxs)

        # Step 4. Compute your loss function. (Again, torch wants the target
        # word wrapped in a tensor.)
        loss = loss_function(
            log_probs, torch.tensor([word_to_ix[target]], dtype=torch.long))

        # Step 5. Do the backward pass and update the parameters.
        loss.backward()
        optimizer.step()

        # Get the Python number from a 1-element tensor by calling tensor.item().
        total_loss += loss.item()
    losses.append(total_loss)
print(losses)  # The loss decreased every iteration over the training data!

# VisualDL setup
logw = LogWriter("./embedding_log", sync_cycle=10000)
with logw.mode('train') as logger:
    embedding = logger.embedding()

embeddings_list = model.embeddings.weight.data.numpy()  # convert to a numpy array

# The VisualDL embedding log writer takes two parameters:
# the first is the embedding list, of type list[list[float]];
# the second is word_dict, of type dictionary<string, int>.
embedding.add_embeddings_with_word_dict(embeddings_list, word_to_ix)
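
The pairing that matters here: model.embeddings has one EMBEDDING_DIM-sized row per vocabulary word, and word_to_ix assigns each word a distinct row index, which is exactly the contract add_embeddings_with_word_dict expects. A few sanity checks one could append to the end of the demo script (a sketch, not part of the committed file):

# The weight matrix of nn.Embedding(len(vocab), EMBEDDING_DIM) has one row per word.
assert embeddings_list.shape == (len(vocab), EMBEDDING_DIM)

# word_to_ix maps every vocabulary word to a distinct row index.
assert sorted(word_to_ix.values()) == list(range(len(vocab)))

# The vector VisualDL will display for a word is simply its row in the matrix.
when_vector = embeddings_list[word_to_ix["When"]]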

visualdl/logic/pybind.cc

Lines changed: 26 additions & 10 deletions
@@ -42,6 +42,12 @@ PYBIND11_MODULE(core, m) {
        .. autoclass:: ImageWriter
           :members:
 
+       .. autoclass:: TextWriter
+          :members:
+
+       .. autoclass:: AudioWriter
+          :members:
+
     )pbdoc";
 
   py::class_<vs::LogReader>(m, "LogReader")
@@ -240,7 +246,7 @@ PYBIND11_MODULE(core, m) {
        Add a record with the step and text value.
 
        :param step: Current step value
-       :type index: integer
+       :type step: integer
        :param text: Text record
        :type text: basestring
       )pbdoc");
@@ -257,15 +263,25 @@ PYBIND11_MODULE(core, m) {
        PyBind class. Must instantiate through the LogWriter.
       )pbdoc")
       .def("set_caption", &cp::Embedding::SetCaption)
-      .def(
-          "add_embeddings_with_word_list"
-          R"pbdoc(
-        Add embedding record. Each run can only store one embedding data.
-
-        :param embedding: hot vector of embedding words
-        :type embedding: list
-      )pbdoc",
-          &cp::Embedding::AddEmbeddingsWithWordList);
+      .def("add_embeddings_with_word_dict",
+           &cp::Embedding::AddEmbeddingsWithWordDict,
+           R"pbdoc(
+        Add the embedding record. Each run can only store one embedding data. **embeddings** and **word_dict** should be
+        the same length. The **word_dict** is used to find the word embedding index in **embeddings**::
+
+            embeddings = [[-1.5246837, -0.7505612, -0.65406495, -1.610278],
+                          [-0.781105, -0.24952792, -0.22178008, 1.6906816]]
+
+            word_dict = {"Apple": 0, "Orange": 1}
+
+        This shows that ``"Apple"`` is embedded to ``[-1.5246837, -0.7505612, -0.65406495, -1.610278]`` and
+        ``"Orange"`` is embedded to ``[-0.781105, -0.24952792, -0.22178008, 1.6906816]``.
+
+        :param embeddings: list of word embeddings
+        :type embeddings: list
+        :param word_dict: The mapping from words to indices.
+        :type word_dict: dictionary
+      )pbdoc");
 
   py::class_<cp::EmbeddingReader>(m, "EmbeddingReader")
       .def("get_all_labels", &cp::EmbeddingReader::get_all_labels)

visualdl/logic/sdk.cc

Lines changed: 7 additions & 5 deletions
@@ -350,17 +350,19 @@ size_t TextReader::size() const { return reader_.total_records(); }
 /*
  * Embedding functions
  */
-void Embedding::AddEmbeddingsWithWordList(
+void Embedding::AddEmbeddingsWithWordDict(
     const std::vector<std::vector<float>>& word_embeddings,
-    std::vector<std::string>& labels) {
-  for (int i = 0; i < word_embeddings.size(); i++) {
-    AddEmbedding(i, word_embeddings[i], labels[i]);
+    std::map<std::string, int>& word_dict) {
+  for (auto& word_index_pair : word_dict) {
+    AddEmbedding(word_index_pair.second,
+                 word_embeddings[word_index_pair.second],
+                 word_index_pair.first);
   }
 }
 
 void Embedding::AddEmbedding(int item_id,
                              const std::vector<float>& one_hot_vector,
-                             std::string& label) {
+                             const std::string& label) {
   auto record = tablet_.AddRecord();
   record.SetId(item_id);
   time_t time = std::time(nullptr);
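
In other words, the rewritten loop walks the dictionary and emits one record per word: the record id is the word's index, the vector is the row at that index, and the label is the word itself. A rough Python paraphrase of that logic (function and variable names here are illustrative, not part of the SDK):

def add_embeddings_with_word_dict(word_embeddings, word_dict):
    records = []
    for word, index in word_dict.items():
        # Mirrors AddEmbedding(word_index_pair.second,
        #                      word_embeddings[word_index_pair.second],
        #                      word_index_pair.first)
        records.append((index, word_embeddings[index], word))
    return records

print(add_embeddings_with_word_dict(
    [[0.1, 0.2], [0.3, 0.4]], {"Apple": 0, "Orange": 1}))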

visualdl/logic/sdk.h

Lines changed: 4 additions & 9 deletions
@@ -337,19 +337,14 @@ struct Embedding {
   void SetCaption(const std::string cap) {
     tablet_.SetCaptions(std::vector<std::string>({cap}));
   }
-
-  // Add all word vectors along with all labels
-  // The index of labels should match with the index of word_embeddings
-  // EX: ["Apple", "Orange"] means the first item in word_embeddings represents
-  // "Apple"
-  void AddEmbeddingsWithWordList(
+  void AddEmbeddingsWithWordDict(
       const std::vector<std::vector<float>>& word_embeddings,
-      std::vector<std::string>& labels);
-  // TODO: Create another function that takes 'word_embeddings' and 'word_dict'
+      std::map<std::string, int>& word_dict);
+
 private:
   void AddEmbedding(int item_id,
                     const std::vector<float>& one_hot_vector,
-                    std::string& label);
+                    const std::string& label);
 
   Tablet tablet_;
 };
