Skip to content

Commit c3fd2c2

Browse files
authored
Merge pull request #7002 from qingqing01/imdb_data
Speed data reader for IMDB dataset.
2 parents f839154 + eb8edeb commit c3fd2c2

File tree

1 file changed

+13
-40
lines changed

1 file changed

+13
-40
lines changed

python/paddle/v2/dataset/imdb.py

Lines changed: 13 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,9 @@
2323
import paddle.v2.dataset.common
2424
import collections
2525
import tarfile
26-
import Queue
2726
import re
2827
import string
29-
import threading
28+
import random
3029

3130
__all__ = ['build_dict', 'train', 'test', 'convert']
3231

@@ -74,47 +73,21 @@ def build_dict(pattern, cutoff):
7473
return word_idx
7574

7675

77-
def reader_creator(pos_pattern, neg_pattern, word_idx, buffer_size):
76+
def reader_creator(pos_pattern, neg_pattern, word_idx):
7877
UNK = word_idx['<unk>']
78+
INS = []
7979

80-
qs = [Queue.Queue(maxsize=buffer_size), Queue.Queue(maxsize=buffer_size)]
81-
82-
def load(pattern, queue):
80+
def load(pattern, out, label):
8381
for doc in tokenize(pattern):
84-
queue.put(doc)
85-
queue.put(None)
82+
out.append(([word_idx.get(w, UNK) for w in doc], label))
83+
84+
load(pos_pattern, INS, 0)
85+
load(neg_pattern, INS, 1)
86+
random.shuffle(INS)
8687

8788
def reader():
88-
# Creates two threads that loads positive and negative samples
89-
# into qs.
90-
t0 = threading.Thread(
91-
target=load, args=(
92-
pos_pattern,
93-
qs[0], ))
94-
t0.daemon = True
95-
t0.start()
96-
97-
t1 = threading.Thread(
98-
target=load, args=(
99-
neg_pattern,
100-
qs[1], ))
101-
t1.daemon = True
102-
t1.start()
103-
104-
# Read alternatively from qs[0] and qs[1].
105-
i = 0
106-
doc = qs[i].get()
107-
while doc != None:
108-
yield [word_idx.get(w, UNK) for w in doc], i % 2
109-
i += 1
110-
doc = qs[i % 2].get()
111-
112-
# If any queue is empty, reads from the other queue.
113-
i += 1
114-
doc = qs[i % 2].get()
115-
while doc != None:
116-
yield [word_idx.get(w, UNK) for w in doc], i % 2
117-
doc = qs[i % 2].get()
89+
for doc, label in INS:
90+
yield doc, label
11891

11992
return reader
12093

@@ -133,7 +106,7 @@ def train(word_idx):
133106
"""
134107
return reader_creator(
135108
re.compile("aclImdb/train/pos/.*\.txt$"),
136-
re.compile("aclImdb/train/neg/.*\.txt$"), word_idx, 1000)
109+
re.compile("aclImdb/train/neg/.*\.txt$"), word_idx)
137110

138111

139112
def test(word_idx):
@@ -150,7 +123,7 @@ def test(word_idx):
150123
"""
151124
return reader_creator(
152125
re.compile("aclImdb/test/pos/.*\.txt$"),
153-
re.compile("aclImdb/test/neg/.*\.txt$"), word_idx, 1000)
126+
re.compile("aclImdb/test/neg/.*\.txt$"), word_idx)
154127

155128

156129
def word_dict():

0 commit comments

Comments (0)